tuxedocat's picture
Use new auth lib
918d142
raw
history blame
No virus
4.55 kB
from functools import partial
import gradio as gr
import httpx
import subprocess
import os
from openai import OpenAI
from cycloud.auth import load_default_credentials
from const import (
LLM_BASE_URL,
AUTH_CMD,
SYSTEM_PROMPTS,
EXAMPLES,
CSS,
HEADER,
FOOTER,
PLACEHOLDER,
ModelInfo,
MODELS,
)
def get_headers(host: str) -> dict:
    """Return the header set for an authenticated JSON request to ``host``.

    ``load_default_credentials`` is called on every invocation and its
    ``access_token`` is sent as a Bearer token; the ``Host`` header is set
    explicitly to ``host`` and both request and response are declared JSON.
    """
    access_token = load_default_credentials().access_token
    json_mime = "application/json"
    headers = {"Authorization": f"Bearer {access_token}", "Host": host}
    headers["Accept"] = json_mime
    headers["Content-Type"] = json_mime
    return headers
def proxy(request: httpx.Request, model_info: ModelInfo) -> httpx.Request:
    """Rewrite an outgoing request so it targets ``model_info``'s endpoint.

    Used as an httpx ``"request"`` event hook (bound to one model via
    ``functools.partial`` in ``call_llm``). The request is mutated in place:
    its URL path is replaced by ``model_info.endpoint``, then the headers from
    ``get_headers`` (auth token, ``Host``, JSON content type) are merged in.
    The same request object is returned.
    """
    rerouted_url = request.url.copy_with(path=model_info.endpoint)
    request.url = rerouted_url
    auth_headers = get_headers(host=model_info.host)
    request.headers.update(auth_headers)
    return request
def _build_messages(message: str, history: list, system_prompt_text: str) -> list[dict]:
    """Convert Gradio chat history plus the new user turn into chat-API messages.

    The system prompt is always the first message so it applies on every turn,
    not only on the first one.

    ``history`` may be in Gradio's "tuples" format (``[user, assistant]`` pairs)
    or its "messages" format (dicts with ``role`` and ``content`` keys).
    Entries whose content is ``None`` are skipped rather than sent as null.
    """
    messages = [{"role": "system", "content": system_prompt_text}]
    for turn in history:
        if isinstance(turn, dict):
            # "messages" format: forward only the fields the chat API expects.
            if turn.get("content") is not None:
                messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn
        if user_text is not None:
            messages.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


def call_llm(
    message: str,
    history: list,
    model_name: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a chat completion for ``message`` from the selected model.

    Generator used as the ``gr.ChatInterface`` callback: after each streamed
    chunk it yields the assistant reply accumulated so far.

    Args:
        message: The new user message.
        history: Prior turns as supplied by Gradio (see ``_build_messages``).
        model_name: Key into ``MODELS`` selecting the backend model.
        system_prompt: Key into ``SYSTEM_PROMPTS`` selecting the prompt text.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling threshold forwarded to the API.

    Yields:
        The cumulative assistant reply text.

    Raises:
        KeyError: If ``system_prompt`` or ``model_name`` is not a known key.
    """
    messages = _build_messages(message, history, SYSTEM_PROMPTS[system_prompt])
    model_info = MODELS[model_name]

    # The HTTP client is closed when the generator finishes or is closed early
    # (e.g. "Stop Generation"), which also releases the streaming connection.
    with httpx.Client(
        # Every request is rerouted to the model's endpoint and given auth
        # headers by the `proxy` hook.
        event_hooks={"request": [partial(proxy, model_info=model_info)]},
        # NOTE(review): TLS certificate verification is disabled; confirm the
        # backend really requires this.
        verify=False,
    ) as http_client:
        # Auth is injected by the `proxy` hook, so no API key is configured here.
        client = OpenAI(api_key="", base_url=LLM_BASE_URL, http_client=http_client)
        stream = client.chat.completions.create(
            model=f"/data/cyberagent/{model_info.name}",
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            n=1,
            stream=True,
            extra_body={"repetition_penalty": 1.1},
        )
        response = ""
        for chunk in stream:
            # Some servers emit chunks without choices (e.g. usage-only chunks).
            if not chunk.choices:
                continue
            response += chunk.choices[0].delta.content or ""
            yield response
def run():
    """Build the Gradio chat UI around ``call_llm`` and launch the server.

    Model and system-prompt selectors are hidden and default to the first
    entry of ``MODELS`` / ``SYSTEM_PROMPTS``; only the sampling sliders are
    visible, inside a collapsed "Parameters" accordion.
    """
    model_names = list(MODELS.keys())
    prompt_names = list(SYSTEM_PROMPTS.keys())

    # NOTE(review): CSS and PLACEHOLDER are imported from const but unused here;
    # elem_id="chatbot" suggests custom CSS may target this component — confirm
    # whether gr.Blocks(css=CSS) / a chatbot placeholder was intended.
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        scale=1,
        show_copy_button=True,
        height="70%",
        layout="panel",
    )
    with gr.Blocks(fill_height=True) as demo:
        gr.Markdown(HEADER)
        gr.ChatInterface(
            fn=call_llm,
            stop_btn="Stop Generation",
            examples=EXAMPLES,
            cache_examples=False,
            multimodal=False,
            chatbot=chatbot,
            additional_inputs_accordion=gr.Accordion(
                label="Parameters", open=False, render=False
            ),
            # Passed to call_llm after (message, history); the order must match
            # its parameters: model_name, system_prompt, max_tokens,
            # temperature, top_p.
            additional_inputs=[
                gr.Dropdown(
                    choices=model_names,
                    value=model_names[0],
                    label="Model",
                    visible=False,
                    render=False,
                ),
                gr.Dropdown(
                    choices=prompt_names,
                    value=prompt_names[0],
                    label="System Prompt",
                    visible=False,
                    render=False,
                ),
                gr.Slider(
                    minimum=1,
                    maximum=4096,
                    step=1,
                    value=1024,
                    label="Max tokens",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.3,
                    label="Temperature",
                    visible=True,
                    render=False,
                ),
                gr.Slider(
                    # top_p must be > 0: vLLM-style backends reject 0 with
                    # "top_p must be in (0, 1]".
                    minimum=0.1,
                    maximum=1,
                    step=0.1,
                    value=1.0,
                    label="Top-p",
                    visible=True,
                    render=False,
                ),
            ],
            analytics_enabled=False,
        )
        gr.Markdown(FOOTER)
    demo.queue(max_size=256, api_open=False)
    demo.launch(share=False, quiet=True)


if __name__ == "__main__":
    run()