# Source: BaseChat / app_single.py (exported from a web file-viewer page)
# Uploader: yuchenlin | commit "side by side" (d8f6559) | 4.63 kB
import gradio as gr
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request
from constant import js_code_label, HEADER_MD
from openai import OpenAI
import datetime
# Emit INFO-level (and above) log records through the root logger.
logging.basicConfig(level=logging.INFO)

# Name of the URIAL prompt file to download and the URL it is fetched from.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"

# Download the prompt once at import time. The context manager guarantees the
# HTTP response is closed even if reading or decoding fails.
with urllib.request.urlopen(URIAL_URL) as _urial_response:
    urial_prompt = _urial_response.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""')  # new version of URIAL uses """ instead of ```

# Substrings that end generation when they appear in the streamed output.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Per-client request counts (keyed by client host) and the time the counts
# were last reset; both are updated by `respond`.
addr_limit_counter = {}
LAST_UPDATE_TIME = datetime.datetime.now()
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    api_key,
    request: gr.Request,
):
    """Stream a base-LLM completion for the chat UI using the URIAL prompt.

    The prompt is built with ``urial_template`` from the module-level URIAL
    prompt, the chat history and the new message, then sent through
    ``openai_base_request``. Streamed chunks are accumulated and the text so
    far is yielded after every chunk.

    Args:
        message: The new user message.
        history: Previous (user, assistant) turns of the conversation.
        max_tokens: Forwarded to the completion request.
        temperature: Forwarded to the completion request.
        top_p: Forwarded to the completion request.
        rp: Repetition penalty from the UI (currently overridden, see below).
        model_name: Display name of the model; mapped with
            ``model_name_mapping`` before the request is made.
        api_key: Optional user-supplied key; only 64-character keys are used.
        request: The Gradio request, used to read the client host for the
            per-client daily request limit.

    Yields:
        The accumulated response text, or a single limit-reached message when
        the client has used up its daily quota.
    """
    global LAST_UPDATE_TIME, addr_limit_counter

    # NOTE(review): the value coming from the UI is overridden here, so the
    # "Repetition penalty" input currently has no effect - confirm intended.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)

    # Anything that is not a 64-character key is discarded and None is passed.
    if not (api_key and len(api_key) == 64):
        api_key = None

    # Reset every client's counter once more than 24 hours have passed.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    daily_limit = 100
    # NOTE(review): the limit also applies when the user supplies their own
    # key, although the message suggests otherwise - confirm desired policy.
    if addr_limit_counter.setdefault(host_addr, 0) >= daily_limit:
        # This function is a generator, so the message must be yielded; a
        # plain `return <value>` would never reach the chat window.
        yield "You have reached the limit of 100 requests for today. Please use your own API key."
        return

    infer_request = openai_base_request(
        prompt=prompt,
        model=_model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=rp,
        stop=STOP_STRS,
        api_key=api_key,
    )
    addr_limit_counter[host_addr] += 1
    logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")

    response = ""
    for msg in infer_request:
        if not msg.choices:
            # Some stream chunks carry no choices; nothing to append.
            continue
        choice = msg.choices[0]
        if hasattr(choice, "delta"):
            delta = choice.delta
            # Accept both mapping-style and attribute-style deltas.
            if isinstance(delta, dict):
                token = delta.get("content")
            else:
                token = getattr(delta, "content", None)
        else:
            token = choice.text
        if token is None:
            # A chunk without text (e.g. role-only or final chunk).
            continue

        candidate = response + token
        if any(stop in candidate for stop in STOP_STRS):
            break
        response = candidate

        # Hide a trailing newline + one or two quotes, which may be the start
        # of the triple-quote stop string. Only the yielded view is trimmed:
        # the accumulated text keeps the quotes so a stop string that arrives
        # piecewise is still detected, and legitimate quotes are not lost.
        if response.endswith('\n""'):
            visible = response[:-2]
        elif response.endswith('\n"'):
            visible = response[:-1]
        else:
            visible = response
        yield visible
# Build the Gradio page: header and model picker, request parameters, and the
# chat interface wired to `respond`.
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed from the statement order - confirm it matches the intended
# layout.
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(
                [
                    "Llama-3.1-405B-FP8",
                    "Llama-3-70B",
                    "Llama-3-8B",
                    "Mistral-7B-v0.1",
                    "Mixtral-8x22B",
                    "Qwen1.5-72B",
                    "Yi-34B",
                    "Llama-2-7B",
                    "Llama-2-70B",
                    "OLMO",
                ],
                value="Llama-3.1-405B-FP8",
                label="Base LLM name",
            )
        with gr.Column():
            # Created hidden (visible=False); its value is still passed to
            # `respond` through `additional_inputs`.
            api_key = gr.Textbox(
                label="🔑 APIKey",
                placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
                type="password",
                elem_id="api_key",
                visible=False,
            )
            # with gr.Column():
            with gr.Accordion("⚙️ Parameters for Base LLM", open=True):
                with gr.Row():
                    max_tokens = gr.Textbox(value=256, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
    # with gr.Row():
    # The order of `additional_inputs` must match the parameter order of
    # `respond` after (message, history).
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, api_key],
        # additional_inputs_accordion="⚙️ Parameters",
        # fill_height=True,
    )
    chat.chatbot.label = "Chat with Base LLMs via URIAL"
    chat.chatbot.height = 550
    chat.chatbot.show_copy_button = True

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_api=False)