"""Gradio chat demo serving *base* (untuned) LLMs via URIAL in-context alignment.

The URIAL prompt — a fixed block of in-context alignment examples — is fetched
once at import time and prepended to every conversation so that a base model
(Llama-3 8B/70B served through the Together API) behaves like a chat assistant.
"""

import logging
import os
import urllib.request
from typing import List

import gradio as gr
from openai import OpenAI

# Emit request parameters to the console for easier debugging.
logging.basicConfig(level=logging.INFO)

BASE_URL = "https://api.together.xyz/v1"
DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")

# Fetch the URIAL in-context alignment prompt once at startup.
URIAL_VERSION = "inst_1k_v4.help"
urial_url = (
    "https://raw.githubusercontent.com/Re-Align/URIAL/main/"
    f"urial_prompts/{URIAL_VERSION}.txt"
)
urial_prompt = urllib.request.urlopen(urial_url).read().decode("utf-8")
# The hosted prompt delimits turns with ``` fences; this app uses """ instead.
urial_prompt = urial_prompt.replace("```", '"""')

# Generation must halt as soon as the model starts emitting a new turn marker.
stop_str = ['"""', '# Query:', '# Answer:']


def urial_template(urial_prompt, history, message):
    """Render the URIAL prefix + chat history + new message as one prompt.

    Args:
        urial_prompt: Static in-context alignment prefix.
        history: List of ``(user_msg, ai_msg)`` pairs from previous turns.
        message: The new user message.

    Returns:
        A completion prompt ending with an open ``# Answer:`` block for the
        model to fill in.
    """
    current_prompt = urial_prompt + "\n"
    for user_msg, ai_msg in history:
        current_prompt += f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
    current_prompt += f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n'
    return current_prompt


def openai_base_request(
    model: str = None,
    temperature: float = 0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    prompt: str = None,
    n: int = 1,
    repetition_penalty: float = 1.0,
    stop: List[str] = None,
    api_key: str = None,
):
    """Issue a streaming text-completion request against the Together API.

    Args:
        model: Fully qualified model id (e.g. ``meta-llama/Llama-3-8b-hf``).
        temperature: Sampling temperature (coerced to float).
        max_tokens: Generation cap (coerced to int).
        top_p: Nucleus-sampling probability mass (coerced to float).
        prompt: The full completion prompt.
        n: Number of completions to request.
        repetition_penalty: Together-specific penalty, sent via ``extra_body``.
        stop: Stop sequences; generation halts server-side on any of them.
        api_key: Together API key; falls back to ``DEFAULT_API_KEY`` if None.

    Returns:
        The streaming response iterator from ``client.completions.create``.
    """
    if api_key is None:
        api_key = DEFAULT_API_KEY
    client = OpenAI(api_key=api_key, base_url=BASE_URL)

    logging.info(f"Requesting chat completion from OpenAI API with model {model}")
    logging.info(f"Prompt: {prompt}")
    logging.info(f"Temperature: {temperature}")
    logging.info(f"Max tokens: {max_tokens}")
    logging.info(f"Top-p: {top_p}")
    logging.info(f"Repetition penalty: {repetition_penalty}")
    logging.info(f"Stop: {stop}")

    request = client.completions.create(
        model=model,
        prompt=prompt,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        n=n,
        # repetition_penalty is not part of the OpenAI schema; Together reads
        # it from the request body, hence extra_body.
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True,
    )
    return request


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key,
):
    """Stream an assistant reply for the Gradio ChatInterface.

    Yields progressively longer partial responses; streaming ends as soon as
    any URIAL turn marker appears in the accumulated text.

    Raises:
        ValueError: If ``model_name`` is not one of the supported choices.
    """
    global stop_str, urial_prompt
    # BUG FIX: the original hard-coded ``rp = 1.0`` here, silently discarding
    # the "Repetition penalty" UI control (default 1.1). Pass the user's value
    # through instead; openai_base_request coerces it to float.
    prompt = urial_template(urial_prompt, history, message)
    if model_name == "Llama-3-8B":
        _model_name = "meta-llama/Llama-3-8b-hf"
    elif model_name == "Llama-3-70B":
        _model_name = "meta-llama/Llama-3-70b-hf"
    else:
        raise ValueError("Invalid model name")

    # Together API keys are 64 characters; anything else falls back to the
    # server-side default key.
    if together_api_key and len(together_api_key) == 64:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY

    request = openai_base_request(
        prompt=prompt,
        model=_model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=rp,
        stop=stop_str,
        api_key=api_key,
    )

    response = ""
    for msg in request:
        # NOTE(review): Together's streamed completion chunks are read via
        # ``.delta["content"]`` here; the standard OpenAI completions stream
        # exposes ``.text`` instead — confirm against the deployed client.
        token = msg.choices[0].delta["content"]
        response += token
        # Once any stop marker lands in the accumulated text, stop without
        # yielding the partial that contains the marker.
        if any(marker in response for marker in stop_str):
            break
        yield response


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Label("Welcome to the URIAL Chatbot!")
            model_name = gr.Radio(
                ["Llama-3-8B", "Llama-3-70B"],
                value="Llama-3-8B",
                label="Base model name",
            )
            together_api_key = gr.Textbox(
                label="Together API Key",
                placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.",
                type="password",
            )
        with gr.Column():
            with gr.Column():
                with gr.Row():
                    max_tokens = gr.Textbox(value=1024, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
            with gr.Column():
                with gr.Row():
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
    )
    chat.chatbot.height = 600


if __name__ == "__main__":
    demo.launch()