File size: 4,983 Bytes
95167e7
 
 
 
c02118d
95167e7
 
 
 
 
 
c02118d
 
677e3b7
 
b155bbd
5510202
95167e7
 
 
 
 
 
 
 
 
 
 
 
e864faa
c02118d
37ba605
 
0299602
120bd05
04320ba
95167e7
 
 
 
04320ba
 
 
 
 
95167e7
 
c02118d
677e3b7
95167e7
 
c02118d
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
c02118d
 
 
 
 
0094eb2
57d5907
0299602
c02118d
5510202
e864faa
95167e7
b05cdf7
444efa4
0299602
 
95167e7
 
 
0299602
95167e7
 
120bd05
95167e7
 
 
f7d5ff7
5510202
0d2c418
5510202
95167e7
 
0299602
95167e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677e3b7
b05cdf7
95167e7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
from huggingface_hub import Repository, InferenceClient
import os
import json
import re

# --- Configuration pulled from the environment (all may be None if unset) ---
API_TOKEN = os.environ.get("API_TOKEN")        # HF token sent as Bearer auth to the inference endpoints
API_ENDPOINT = os.environ.get("API_ENDPOINT")  # NOTE(review): read but never used below — confirm it is needed

KEY = os.environ.get("KEY")                    # shared secret checked in predict() before serving a request

# Delimiters wrapped around every chat turn in the prompt; also used to
# extract the model's reply and as the default stop sequences in the UI.
SPECIAL_SYMBOLS = ["‹", "›"]

DEFAULT_INPUT = "You: Hi!"
DEFAULT_PREOUTPUT = "AI: "
# Fixed grammar: "an helpful" -> "a helpful".
DEFAULT_INSTRUCTION = "You are a helpful chatbot."

# Display name -> HF model repo id.
API_ENDPOINTS = {
    "Falcon": "tiiuae/falcon-180B-chat",
    "Llama": "meta-llama/Llama-2-70b-chat-hf"
}

CHOICES = []   # dropdown options (display names)
CLIENTS = {}   # display name -> InferenceClient bound to that model

for model_name, model_endpoint in API_ENDPOINTS.items():
    CHOICES.append(model_name)
    CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })

def format(instruction, history, input, preoutput):
    """Build the raw prompt string sent to the model.

    Each turn is wrapped in the SPECIAL_SYMBOLS delimiters:
    ‹System: instruction›, then the user/bot pairs from *history*,
    then the new *input*, ending with an unclosed ‹preoutput that the
    model is expected to continue.
    """
    left, right = SPECIAL_SYMBOLS
    pieces = [f"{left}System: {instruction}{right}"]
    for pair in history:
        pieces.append(f"{left}{pair[0]}{right}")
        pieces.append(f"{left}{pair[1]}{right}")
    pieces.append(f"{left}{input}{right}")
    # Trailing segment is deliberately left without a closing delimiter.
    return "\n".join(pieces) + f"\n{left}{preoutput}"
    
def predict(instruction, history, input, preoutput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
    """Run one chat turn against the selected model.

    Returns a 3-tuple matching the Gradio wiring (output, input, history).
    On a wrong access key, returns an error marker without calling the model.
    """
    history = history or []

    if access_key != KEY:
        # f-string instead of "+" so a None input/access_key cannot raise TypeError here.
        print(f">>> MODEL FAILED: Input: {input}, Attempted Key: {access_key}")
        # Bug fix: must return THREE values — run.click wires outputs=[output, input, history],
        # and the original 2-tuple broke the unauthorized path.
        return ("[UNAUTHORIZED ACCESS]", input, history)

    # Fall back to defaults for any empty/None fields.
    instruction = instruction or DEFAULT_INSTRUCTION
    input = input or ""
    preoutput = preoutput or ""

    # stop_seqs is a JSON array typed into the UI, e.g. '["‹", "›"]'.
    stops = json.loads(stop_seqs)

    formatted_input = format(instruction, history, input, preoutput)

    response = CLIENTS[model].text_generation(
        formatted_input,
        temperature = temperature,
        max_new_tokens = max_tokens,
        top_p = top_p,
        top_k = top_k,
        repetition_penalty = rep_p,
        stop_sequences = stops,
        do_sample = True,
        seed = seed,
        stream = False,
        details = False,
        return_full_text = False
    )

    # Wrap the reply in delimiters, then pull out the first delimited span;
    # the appended "‹›" guarantees the pattern can close even if the model
    # emitted no closing symbol itself.
    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
    pre_result = f"{sy_l}{response}{sy_r}{''.join(SPECIAL_SYMBOLS)}"
    pattern = re.compile(f"{sy_l}(.*?){sy_r}", re.DOTALL)
    match = pattern.search(pre_result)
    get_result = match.group(1).strip() if match else ""

    history = history + [[input, get_result]]

    print(f"---\nUSER: {input}\nBOT: {get_result}\n---")

    return (preoutput + response, input, history)

def clear_history():
    """Wipe the conversation; the returned empty list replaces the Chatbot state."""
    print(">>> HISTORY CLEARED!")
    empty_history = []
    return empty_history
     
def maintain_cloud():
    """Keep-alive ping handler; overwrites the input and output boxes with a success marker."""
    print(">>> SPACE MAINTAINED!")
    status = "SUCCESS!"
    return (status, status)

# --- UI layout and event wiring ---
with gr.Blocks() as demo:
    with gr.Row(variant = "panel"):
        gr.Markdown("🔯 This is a private LLM CHAT Space owned within STC Holdings!\n\n\nhttps://discord.gg/6JRtGawz7B")

    with gr.Row():
        with gr.Column():
            # Bug fix: "abel" was a typo for "label" and raised a TypeError at startup.
            history = gr.Chatbot(label = "History", elem_id = "chatbot")
            input = gr.Textbox(label = "Input", value = DEFAULT_INPUT, lines = 2)
            preoutput = gr.Textbox(label = "Pre-Output", value = DEFAULT_PREOUTPUT, lines = 1)
            instruction = gr.Textbox(label = "Instruction", value = DEFAULT_INSTRUCTION, lines = 4)
            access_key = gr.Textbox(label = "Access Key", lines = 1)
            run = gr.Button("▶")
            clear = gr.Button("🗑️")
            cloud = gr.Button("☁️")

        with gr.Column():
            model = gr.Dropdown(choices = CHOICES, value = next(iter(API_ENDPOINTS)), interactive = True, label = "Model")
            temperature = gr.Slider( minimum = 0, maximum = 2, value = 1, step = 0.01, interactive = True, label = "Temperature" )
            top_p = gr.Slider( minimum = 0.01, maximum = 0.99, value = 0.95, step = 0.01, interactive = True, label = "Top P" )
            top_k = gr.Slider( minimum = 1, maximum = 2048, value = 50, step = 1, interactive = True, label = "Top K" )
            rep_p = gr.Slider( minimum = 0.01, maximum = 2, value = 1.2, step = 0.01, interactive = True, label = "Repetition Penalty" )
            # NOTE(review): step=64 with minimum=1 puts the grid at 1, 65, 129, ... so the
            # default 32 is off-grid — likely step was meant to be 1; confirm before changing.
            max_tokens = gr.Slider( minimum = 1, maximum = 2048, value = 32, step = 64, interactive = True, label = "Max New Tokens" )
            stop_seqs = gr.Textbox(label = "Stop Sequences ( JSON Array / 4 Max )", lines = 1, value = '["‹", "›"]')
            seed = gr.Slider( minimum = 0, maximum = 8192, value = 42, step = 1, interactive = True, label = "Seed" )

    with gr.Row():
        with gr.Column():
            output = gr.Textbox(label = "Output", value = "", lines = 50)

    # predict consumes every control and fills (output, input, history).
    run.click(predict, inputs = [instruction, history, input, preoutput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input, history])
    clear.click(clear_history, [], history)
    cloud.click(maintain_cloud, inputs = [], outputs = [input, output])

demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)