File size: 5,271 Bytes
95167e7
 
 
 
c02118d
95167e7
 
 
 
 
 
e0842ef
c02118d
caf7748
 
 
5510202
f26090a
1a1d724
95167e7
 
80dce94
 
 
 
95167e7
 
 
 
 
 
 
 
 
e864faa
c02118d
37ba605
3bb6f14
f9662b0
120bd05
7b5cfb2
95167e7
 
 
d762b96
04320ba
 
 
 
 
1a1d724
95167e7
 
c02118d
7822b29
cc6ed2c
25f48f1
95167e7
c02118d
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
c02118d
25f48f1
7822b29
 
df6c8dd
7822b29
df6c8dd
d762b96
5bd743c
8465750
7822b29
5510202
7822b29
95167e7
b05cdf7
444efa4
0299602
 
95167e7
 
 
0299602
95167e7
 
120bd05
95167e7
 
 
f7d5ff7
5510202
0d2c418
5510202
95167e7
 
0299602
95167e7
 
 
 
 
 
 
 
 
1a1d724
b14f3c1
95167e7
 
 
 
 
7b5cfb2
b05cdf7
95167e7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from huggingface_hub import Repository, InferenceClient
import os
import json
import re

# Hugging Face Inference API credentials, read from the Space's
# environment secrets (None if unset).
API_TOKEN = os.environ.get("API_TOKEN")
API_ENDPOINT = os.environ.get("API_ENDPOINT")

# Shared access secret checked by predict(); requests whose access_key
# does not match are rejected.
KEY = os.environ.get("KEY")

# Delimiter pair wrapped around every message when building the prompt
# (curly double quotes; an earlier alternative is kept in the comment).
SPECIAL_SYMBOLS = ['"', '"'] # ["‹", "›"]

# Default widget values for the Gradio UI defined below.
DEFAULT_INPUT = f"User: Hi!"
DEFAULT_PREOUTPUT = f"Statical: "
DEFAULT_INSTRUCTION = "Statical is a helpful chatbot who is communicating with people."

# Default stop sequences as a JSON-encoded array string; predict()
# parses it with json.loads.
DEFAULT_STOPS = '[\"\\\"\"]'

# Dropdown display name -> Hugging Face model repository id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf",
    "Qwen": "Qwen/Qwen-14B-Chat",
    "Mistral": "mistralai/Mistral-7B-v0.1",
    "Mistral2": "mistralai/Mistral-7B-Instruct-v0.1",
}

# Built at import time: dropdown choices and one InferenceClient per model.
CHOICES = []
CLIENTS = {}

for model_name, model_endpoint in API_ENDPOINTS.items():
    CHOICES.append(model_name)
    CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })

def format(instruction, history, input, preoutput):
    """Assemble the delimited prompt sent to the model.

    The instruction line, every (user, bot) turn in *history*, and the
    current *input* are each wrapped in the SPECIAL_SYMBOLS delimiter
    pair; *preoutput* is appended as the seed of the model's reply.

    Returns a tuple ``(prompt_with_preoutput, prompt_without_preoutput)``.
    """
    left, right = SPECIAL_SYMBOLS
    turns = []
    for turn in history:
        turns.append(f"{left}{turn[0]}{right}\n{left}{turn[1]}{right}\n")
    past = "".join(turns)
    base = f"{left}INSTRUCTIONS: {instruction}{right}\n{past}{left}{input}{right}\n{left}"
    return base + preoutput, base
    
def predict(access_key, instruction, history, input, preoutput, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Run one chat turn against the selected model.

    Parameters mirror the Gradio widgets below: *access_key* must equal
    the KEY secret or the request is rejected; *history* is a list of
    [user, bot] pairs; *stop_seqs* is a JSON-encoded array of stop
    strings; the remaining sampling knobs are passed straight through to
    ``InferenceClient.text_generation``.

    Returns a ``(result_text, input_text, updated_history)`` tuple.
    """
    if access_key != KEY:
        print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
        return ("[UNAUTHORIZED ACCESS]", input, [])

    # Fall back to module defaults for any widget left empty.
    instruction = instruction or DEFAULT_INSTRUCTION
    history = history or []
    input = input or ""
    preoutput = preoutput or ""
    stop_seqs = stop_seqs or DEFAULT_STOPS

    stops = json.loads(stop_seqs)

    formatted_input, formatted_input_base = format(instruction, history, input, preoutput)
    print(seed)
    print(formatted_input)
    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    result = preoutput + response

    # Truncate at the first occurrence of any stop sequence, then strip
    # the prompt delimiters from the visible answer.
    # FIX: the delimiter-stripping loop previously iterated `stops` a
    # second time (a no-op after the split) while the freshly computed
    # SPECIAL_SYMBOLS pair went unused; it now strips the delimiters.
    for stop in stops:
        result = result.split(stop, 1)[0]
    for symbol in SPECIAL_SYMBOLS:
        result = result.replace(symbol, '')

    history = history + [[input, result]]

    print(f"---\nUSER: {input}\nBOT: {result}\n---")

    return (result, input, history)

def clear_history():
    """Reset the chatbot state, logging the event, and return an empty history."""
    print(">>> HISTORY CLEARED!")
    return []
     
def maintain_cloud():
    """Keep-alive ping handler; logs and returns success markers for two widgets."""
    print(">>> SPACE MAINTAINED!")
    marker = "SUCCESS!"
    return (marker, marker)

# Gradio UI: chat panel on the left, sampling controls on the right,
# raw model output below. Widget order must match predict()'s signature.
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        with gr.Column():
            # FIX: "abel" was a typo for the "label" keyword argument,
            # which raises TypeError on construction in current Gradio.
            history = gr.Chatbot(label = "History", elem_id = "chatbot")
            input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
            preoutput = gr.Textbox(label = "Pre-Output", value = DEFAULT_PREOUTPUT, lines = 1)
            instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")

        with gr.Column():
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 64, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = DEFAULT_STOPS )
            seed = gr.Slider( minimum = 0, maximum = 9007199254740991, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # Wire the buttons: run -> predict, trash -> clear history,
    # cloud -> keep-alive ping.
    run.click(predict, inputs = [access_key, instruction, history, input, preoutput, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history])
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])

demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)