File size: 5,465 Bytes
95167e7
e8f3085
95167e7
 
c02118d
95167e7
 
 
 
 
2a52702
 
c02118d
514b2f9
 
d1f0e94
5510202
2a52702
1a1d724
95167e7
 
80dce94
 
7f42d00
6e243f5
b09a184
 
53af226
95167e7
 
 
 
 
 
 
 
 
514b2f9
2a52702
 
e2ad91e
2a52702
94c41a5
e2ad91e
120bd05
514b2f9
c3fca4f
95167e7
 
d762b96
04320ba
 
 
 
514b2f9
1a1d724
95167e7
 
c02118d
514b2f9
69d55a1
cc6ed2c
25f48f1
69d55a1
 
95167e7
c02118d
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
514b2f9
7822b29
 
df6c8dd
7822b29
df6c8dd
d762b96
5bd743c
8465750
7822b29
5510202
7822b29
95167e7
b05cdf7
444efa4
0299602
 
95167e7
 
 
0299602
95167e7
 
a52e921
95167e7
 
 
7a9ac91
5510202
514b2f9
5510202
95167e7
 
0299602
95167e7
 
 
e458a61
95167e7
 
 
 
 
c3fca4f
b14f3c1
95167e7
 
 
 
 
d4bf10a
3ca2ec3
 
95167e7
39f1f96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import re

# --- Secrets (read once at import time; either may be None when unset) ------
# Token placed in the Authorization header of the inference clients.
API_TOKEN = os.getenv("API_TOKEN")
# Shared access key a caller must supply before a generation is run.
KEY = os.getenv("KEY")

# --- Prompt delimiters ------------------------------------------------------
# Each list is an [opening, closing] pair of blank-looking Unicode characters
# used to fence AI text and user text inside the prompt.
# NOTE(review): these appear to be a Hangul filler and a blank Braille
# pattern -- confirm before editing, they are easy to lose in an editor.
# Visible alternatives for the user pair: ["‹", "›"] or ['"', '"'].
SPECIAL_SYMBOLS_AI = [
    "ㅤ",
    "ㅤ",
]
SPECIAL_SYMBOLS_USER = [
    "⠀",
    "⠀",
]

# --- UI / generation defaults -----------------------------------------------
DEFAULT_INSTRUCTION = "Conversation: Statical is a helpful chatbot who is communicating with people."
DEFAULT_INPUT = "User: Hi!"
# %-template applied to the model reply; "%s" marks where the reply goes.
DEFAULT_WRAP = "Statical: %s"

# JSON array (kept as text) holding one AI and one user delimiter.
# Visible alternatives: '["‹", "›"]' or '[\"\\\"\"]'.
DEFAULT_STOPS = '["ㅤ", "⠀"]'

# Display name shown in the model dropdown -> Hugging Face model repository id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf",
    "Mistral": "mistralai/Mistral-7B-v0.1",
    "Mistral*": "mistralai/Mistral-7B-Instruct-v0.1",
    "Open-3.5": "openchat/openchat_3.5",
    "Xistral": "mistralai/Mixtral-8x7B-v0.1",
    "Xistral*": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
}


def _auth_headers():
    """Return a fresh Authorization header dict, or None when no token is set.

    Without this guard an unset API_TOKEN would be sent as the literal
    string "Bearer None".
    """
    if not API_TOKEN:
        return None
    return { "Authorization": f"Bearer {API_TOKEN}" }


# Dropdown choices, in the same order as API_ENDPOINTS.
CHOICES = list(API_ENDPOINTS)

# One inference client per selectable model, keyed by display name.
CLIENTS = {
    model_name: InferenceClient(model_endpoint, headers = _auth_headers())
    for model_name, model_endpoint in API_ENDPOINTS.items()
}

def format(instruction, history, input, wrap):
    """Build the raw prompt string sent to the model.

    The prompt is the instruction, every past exchange and the new user
    input, each fenced by the delimiter pairs in SPECIAL_SYMBOLS_AI and
    SPECIAL_SYMBOLS_USER, followed by an opening AI delimiter and whatever
    text ``wrap`` puts in front of its ``%s`` placeholder.

    Note: the name and the ``input`` parameter shadow Python builtins; they
    are kept because callers depend on them.

    Args:
        instruction: Instruction text placed first, fenced as AI text.
        history: Iterable of past exchanges; item ``[0]`` is the user
            message and item ``[1]`` the AI reply. ``None`` is treated as
            empty.
        input: The new user message.
        wrap: ``%``-style template for the reply. An empty value or a
            template that cannot be filled with a single string is used as
            a literal prefix instead of raising.

    Returns:
        A 2-tuple ``(prompt, prompt_base)`` where ``prompt_base`` is the
        prompt without the wrap prefix.
    """
    ai_open, ai_close = SPECIAL_SYMBOLS_AI[0], SPECIAL_SYMBOLS_AI[1]
    user_open, user_close = SPECIAL_SYMBOLS_USER[0], SPECIAL_SYMBOLS_USER[1]

    # Filling the template with "" leaves only its literal prefix. A wrap
    # without a usable "%s" (including "") would make the % operator raise,
    # so fall back to using it verbatim.
    try:
        wrapped_prefix = (wrap or "%s") % ("",)
    except (TypeError, ValueError):
        wrapped_prefix = wrap

    # Each AI reply is closed with the AI *closing* symbol (previously the
    # opening symbol was used twice, which only worked because both symbols
    # are currently identical).
    formatted_history = "".join(
        f"{user_open}{message[0]}{user_close}{ai_open}{message[1]}{ai_close}"
        for message in (history or [])
    )
    formatted_input = f"{ai_open}{instruction}{ai_close}{formatted_history}{user_open}{input}{user_close}{ai_open}"
    return f"{formatted_input}{wrapped_prefix}", formatted_input
    
def _normalize_wrap(wrap):
    """Return a %-template that is guaranteed to accept exactly one string.

    An empty wrap becomes "%s". A wrap that cannot be filled with a single
    string (no placeholder, too many placeholders, stray "%") is treated as
    a literal prefix: its "%" characters are escaped and "%s" is appended.
    """
    template = wrap or "%s"
    try:
        template % ("",)
    except (TypeError, ValueError):
        return template.replace("%", "%%") + "%s"
    return template


def predict(access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Run one generation for the chat UI and return the updated widgets.

    Args:
        access_key: Key supplied by the caller; must equal the KEY secret.
        instruction: Instruction text; DEFAULT_INSTRUCTION when empty.
        history: List of ``[user, bot]`` exchanges so far (may be None).
        input: New user message (shadows the builtin; kept for callers).
        wrap: ``%``-template applied to the model reply.
        model: Display name used to look up the client in CLIENTS.
        temperature, top_p, top_k, rep_p, max_tokens, seed: Sampling
            settings forwarded unchanged to ``text_generation``.
        stop_seqs: JSON array of stop strings; DEFAULT_STOPS when empty.

    Returns:
        A 3-tuple ``(output_text, input, history)``. On a rejected key the
        history is reset to ``[]``; on invalid stop sequences or an unknown
        model the history is returned unchanged with a bracketed error
        message as the output text.

    Raises:
        Whatever ``text_generation`` raises; request errors are not caught.
    """
    import hmac

    input = input or ""

    # Fail closed when no KEY is configured (otherwise a null key sent via
    # the API would compare equal to the unset secret), and compare in
    # constant time. The attempted key is deliberately NOT logged.
    supplied_key = str(access_key or "").encode("utf-8")
    if KEY is None or not hmac.compare_digest(supplied_key, KEY.encode("utf-8")):
        print(f">>> MODEL FAILED: Input: {input}")
        return ("[UNAUTHORIZED ACCESS]", input, [])

    instruction = instruction or DEFAULT_INSTRUCTION
    history = history or []
    stop_seqs = stop_seqs or DEFAULT_STOPS

    # The stop list is user-typed JSON: reject anything that is not an array
    # of strings instead of crashing, and drop empty strings because
    # str.split("") raises ValueError.
    try:
        stops = json.loads(stop_seqs)
    except (TypeError, ValueError):
        stops = None
    if not isinstance(stops, list) or not all(isinstance(stop, str) for stop in stops):
        print(f">>> MODEL FAILED: Invalid stop sequences: {stop_seqs}")
        return ("[INVALID STOP SEQUENCES]", input, history)
    stops = [stop for stop in stops if stop]

    client = CLIENTS.get(model)
    if client is None:
        print(f">>> MODEL FAILED: Unknown model: {model}")
        return ("[UNKNOWN MODEL]", input, history)

    # Make sure the template can take the reply before it is used anywhere.
    wrap = _normalize_wrap(wrap)

    formatted_input, _ = format(instruction, history, input, wrap)

    print(seed)
    print(formatted_input)
    print(model)

    response = client.text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    result = wrap % (response,)

    # Cut the reply at the first occurrence of each stop sequence; after
    # this no stop sequence can remain in the text.
    for stop in stops:
        result = result.split(stop, 1)[0]

    history = history + [[input, result]]

    print(f"---\nUSER: {input}\nBOT: {result}\n---")

    return (result, input, history)

def clear_history():
    """Log the reset and hand back a new, empty history list."""
    empty_history = list()
    print(">>> HISTORY CLEARED!")
    return empty_history
     
def maintain_cloud():
    """Log a maintenance marker and return the same status text twice."""
    status = "SUCCESS!"
    print(">>> SPACE MAINTAINED!")
    return status, status

# Gradio layout: conversation widgets on the left, sampling settings on the
# right, and the raw output underneath.
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("✡️ This is a private LLM Space owned within STC Holdings!")

    with gr.Row():
        with gr.Column():
            history = gr.Chatbot(label = "History", elem_id = "chatbot")
            input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
            wrap = gr.Textbox(label = "Wrap", value = DEFAULT_WRAP, lines = 1)
            instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")

        with gr.Column():
            # Default selection is the first entry of API_ENDPOINTS.
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            # step is 1 so that the default of 32 lies on the slider grid; a
            # step of 64 starting at 1 only allowed 1, 65, 129, ...
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 1, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox( value = DEFAULT_STOPS, interactive = True, label = "Stop Sequences ( JSON Array / 4 Max )" )
            seed = gr.Slider( minimum = 0, maximum = 9007199254740991, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # predict returns (output text, input text, history), matching `outputs`.
    run.click(predict, inputs = [access_key, instruction, history, input, wrap, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history], queue = False)
    # clear_history returns an empty list, which empties the chatbot.
    clear.click(clear_history, [], history, queue = False)
    # maintain_cloud writes its status text into both the input and output boxes.
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output], queue = False)

demo.launch(show_api = True)