import os
import json
from typing import List, Tuple

import gradio as gr

from shared import Client
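# CONFIG is a JSON object mapping model names to API settings and personas.
# Illustrative payload (assumed shape, inferred from the lookups below; the
# env-var names and persona prompt are hypothetical):
#
#   {
#     "neon": {
#       "api_url": "NEON_API_URL",   # env var holding the vLLM endpoint URL
#       "api_key": "NEON_API_KEY",   # env var holding the API key
#       "personas": {"default": "You are NeonLLM, a helpful assistant."}
#     }
#   }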
config = json.loads(os.environ["CONFIG"])

# Note that config stores the *names* of environment variables, which are
# resolved to the actual endpoint URL and key here.
clients = {}
for name, model_config in config.items():
    clients[name] = Client(
        api_url=os.environ[model_config["api_url"]],
        api_key=os.environ[model_config["api_key"]],
        personas=model_config.get("personas", {}),
    )

model_names = list(config.keys())
radio_infos = [f"{name} ({clients[name].vllm_model_name})" for name in model_names]
accordion_info = "Config"
def parse_radio_select(radio_select):
    """Map the tuple of radio values (exactly one non-None) to (model, persona)."""
    value_index = next(i for i, value in enumerate(radio_select) if value is not None)
    model = model_names[value_index]
    persona = radio_select[value_index]
    return model, persona
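
# For example, with radios for ["neon", "mistral"] and a selection on the
# second radio, radio_select == (None, "concise") maps to ("mistral", "concise").
# (Model and persona names here are hypothetical.)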
def respond(
    message,
    history: List[Tuple[str, str]],
    conversational,
    max_tokens,
    *radio_select,
):
    model, persona = parse_radio_select(radio_select)
    client = clients[model]

    messages = []

    # Resolve the persona's system prompt, surfacing a user-visible error
    # if this model does not define the requested persona.
    try:
        system_prompt = client.personas[persona]
    except KeyError:
        supported_personas = list(client.personas.keys())
        raise gr.Error(f"Model '{model}' does not support persona '{persona}', only {supported_personas}")
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

    # In conversational mode, replay the last two exchanges for context.
    if conversational:
        for user_msg, assistant_msg in history[-2:]:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    completion = client.openai.chat.completions.create(
        model=client.vllm_model_name,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0,
        # vLLM-specific sampling options, passed through the OpenAI client.
        extra_body={
            "repetition_penalty": 1.05,
            "use_beam_search": True,
            "best_of": 5,
        },
    )
    response = completion.choices[0].message.content

    return response
# Components
radios = [
    gr.Radio(choices=list(clients[name].personas.keys()), value=None, label=info)
    for name, info in zip(model_names, radio_infos)
]
# Pre-select the first persona of the first model.
radios[0].value = list(clients[model_names[0]].personas.keys())[0]
conversational_checkbox = gr.Checkbox(value=True, label="conversational")
max_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max new tokens")
with gr.Blocks() as blocks:
    # Events
    radio_state = gr.State([radio.value for radio in radios])

    def radio_click(state, *new_state):
        """Keep at most one radio selected across all model groups."""
        changed_index = next(i for i in range(len(state)) if state[i] != new_state[i])
        changed_value = new_state[changed_index]
        clean_state = [changed_value if i == changed_index else None for i in range(len(state))]
        return clean_state, *clean_state
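
    # Hypothetical wiring (assumed, not shown in the listing): route each
    # radio's input event through radio_click so that selecting a persona on
    # one model clears the radios of the others.
    for radio in radios:
        radio.input(radio_click, inputs=[radio_state, *radios], outputs=[radio_state, *radios])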
    # Compile
    with gr.Accordion(label=accordion_info, open=True, render=False) as accordion:
        for radio in radios:
            radio.render()
        conversational_checkbox.render()
        max_tokens_slider.render()
    demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            conversational_checkbox,
            max_tokens_slider,
            *radios,
        ],
        additional_inputs_accordion=accordion,
        title="NeonLLM (v2024-07-03)",
        concurrency_limit=5,
    )
    accordion.render()
if __name__ == "__main__":
    blocks.launch()