File size: 3,626 Bytes
63a5c24
9513cae
63a5c24
9513cae
63a5c24
5fb8127
a3e95e6
 
63a5c24
 
9513cae
 
 
43b8dd7
9513cae
 
a3e95e6
 
 
9513cae
a3e95e6
9513cae
 
5fb8127
 
f027a65
 
 
 
 
 
 
 
 
 
 
 
a3e95e6
5fb8127
 
 
63a5c24
8b0e392
5fb8127
f027a65
5fb8127
f027a65
 
a3e95e6
 
63a5c24
5fb8127
af66144
a3e95e6
af66144
a3e95e6
af66144
9513cae
 
 
8b0e392
 
 
 
 
 
 
5fb8127
 
5b981d0
a3e95e6
63a5c24
5fb8127
63a5c24
 
 
 
 
 
 
 
 
 
 
f027a65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fb8127
 
 
f027a65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import json
from typing import List, Tuple
from collections import OrderedDict

import gradio as gr

from shared import Client


# Deployment configuration is injected as a JSON object via the CONFIG env var.
# Expected shape per entry (inferred from usage below — confirm against deploy):
#   {"<model name>": {"api_url": <env var name>, "api_key": <env var name>,
#                     "personas": {<persona name>: <system prompt or None>}}}
config = json.loads(os.environ['CONFIG'])



# One Client per configured model. Note that config[name]['api_url'] and
# ['api_key'] hold the *names* of further environment variables, which are
# dereferenced here — the secrets themselves never appear in CONFIG.
clients = {}
for name in config:
    model_personas = config[name].get("personas", {})
    client = Client(
        api_url=os.environ[config[name]['api_url']],
        api_key=os.environ[config[name]['api_key']],
        personas=model_personas
    )
    clients[name] = client


# Stable ordering of models (dict preserves insertion order from the JSON).
model_names = list(config.keys())
# Radio-group labels: "<config name> (<actual vLLM model id>)".
radio_infos = [f"{name} ({clients[name].vllm_model_name})" for name in model_names]
accordion_info = "Config"



def parse_radio_select(radio_select):
    """Resolve the flattened per-model radio values to (model, persona).

    ``radio_select`` holds one entry per model, in ``model_names`` order;
    exactly one entry is expected to be non-None (enforced by the UI's
    radio_click handler). Raises StopIteration if nothing is selected.
    """
    selected = next(
        (i, value) for i, value in enumerate(radio_select) if value is not None
    )
    value_index, persona = selected
    return model_names[value_index], persona



def respond(
    message,
    history: List[Tuple[str, str]],
    conversational,
    max_tokens,
    *radio_select,
):
    """Chat handler for gr.ChatInterface.

    Builds an OpenAI-style message list (optional system prompt, optionally
    the last two history turns, then the new user message) and queries the
    model selected via the radio groups. Returns the assistant's reply text.
    Raises gr.Error if the chosen persona is not configured for the model.
    """
    model, persona = parse_radio_select(radio_select)
    client = clients[model]

    # Persona maps to the system prompt; an explicit None means "no system
    # prompt", while a missing key is a user-facing configuration error.
    try:
        system_prompt = client.personas[persona]
    except KeyError:
        supported_personas = list(client.personas.keys())
        raise gr.Error(f"Model '{model}' does not support persona '{persona}', only {supported_personas}")

    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

    # Only the two most recent (user, assistant) turns are replayed to keep
    # the prompt short.
    if conversational:
        for user_turn, assistant_turn in history[-2:]:
            if user_turn:
                messages.append({"role": "user", "content": user_turn})
            if assistant_turn:
                messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": message})

    # Deterministic decoding (temperature=0) with vLLM-specific beam-search
    # options passed through extra_body.
    completion = client.openai.chat.completions.create(
        model=client.vllm_model_name,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0,
        extra_body={
            "repetition_penalty": 1.05,
            "use_beam_search": True,
            "best_of": 5,
        },
    )
    return completion.choices[0].message.content


# Components
# One radio group per model, listing that model's personas; all start
# unselected. NOTE(review): choices is passed a dict_keys view — Gradio
# appears to normalize it, but list(...) would be safer; confirm.
radios = [gr.Radio(choices=clients[name].personas.keys(), value=None, label=info) for name, info in zip(model_names, radio_infos)]
# Pre-select the first persona of the first model so the app starts with a
# valid (model, persona) pair. NOTE(review): mutates the component after
# construction, before render — relies on Gradio reading .value at render time.
radios[0].value = list(clients[model_names[0]].personas.keys())[0]

conversational_checkbox = gr.Checkbox(value=True, label="conversational")
max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max new tokens")



with gr.Blocks() as blocks:
    # Events
    # Server-side mirror of every radio group's current value, used to detect
    # which group the user just changed.
    radio_state = gr.State([radio.value for radio in radios])
    @gr.on(triggers=[radio.input for radio in radios], inputs=[radio_state, *radios], outputs=[radio_state, *radios])
    def radio_click(state, *new_state):
        # Enforce a single selection across all radio groups: find the group
        # whose value differs from the stored state, keep its value, and
        # clear every other group. Assumes exactly one group changed per
        # input event (raises StopIteration otherwise).
        changed_index = next(i for i in range(len(state)) if state[i] != new_state[i])
        changed_value = new_state[changed_index]
        clean_state = [None if i != changed_index else changed_value for i in range(len(state))]

        # First output updates radio_state; the rest push the cleaned values
        # back into the radio components themselves.
        return clean_state, *clean_state

    # Compile
    # render=False defers drawing so the accordion can be handed to
    # ChatInterface as additional_inputs_accordion and rendered afterwards.
    with gr.Accordion(label=accordion_info, open=True, render=False) as accordion:
        [radio.render() for radio in radios]
        conversational_checkbox.render()
        max_tokens_slider.render()

    demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            conversational_checkbox,
            max_tokens_slider,
            *radios,
        ],
        additional_inputs_accordion=accordion,
        title="NeonLLM (v2024-07-03)",
        concurrency_limit=5,
    )
    accordion.render()


# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    blocks.launch()