import copy
import json
import os
import subprocess
import sys
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


def run_command(command):
    """Run a shell command, exiting the whole process if the command fails."""
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    output, error = process.communicate()
    if process.returncode != 0:
        print(f"Error executing command: {command}")
        print(f"Error message: {error.decode('utf-8')}")
        sys.exit(1)
    return output.decode('utf-8')

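# Example: run_command("npm --version") returns the command's stdout as a string;
# a non-zero exit status prints the error and terminates the app via sys.exit(1).
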
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8000"))
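# The input token budget can be overridden at launch via the environment, e.g.:
#   MAX_INPUT_TOKEN_LENGTH=4096 python app.py   # assuming this script is saved as app.py
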
model_choices = [
    "rubra-ai/Meta-Llama-3-8B-Instruct",
    "rubra-ai/Qwen2-7B-Instruct",
    "rubra-ai/Phi-3-mini-128k-instruct",
    "rubra-ai/Mistral-7B-Instruct-v0.3",
    "rubra-ai/Mistral-7B-Instruct-v0.2",
    "rubra-ai/gemma-1.1-2b-it",
]

DESCRIPTION = """\
# Rubra v0.1 - Top LLMs enhanced with function (tool) calling

This is a demo of the Rubra collection of models. You can use the models for general conversation,
task completion, and function calling with the provided tools input.
"""

LICENSE = """
<p/>

---
Rubra code is licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Rubra models are licensed under the parent model's license. See the parent model card for more information.
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Defined unconditionally so the UI can still be built (with the CPU notice) when no GPU is present.
model_id = "sanjay920/Llama-3-8b-function-calling-alpha-v1"
model = None
tokenizer = None


def load_model(model_name):
    """Load the selected model and tokenizer, replacing the current globals."""
    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=False)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.use_default_system_prompt = False
    model.generation_config.pad_token_id = tokenizer.pad_token_id

# Only load weights when a GPU is available; the demo does not run on CPU.
if torch.cuda.is_available():
    load_model(model_id)


def is_valid_json(tools: str) -> bool:
    try:
        json.loads(tools)
        return True
    except ValueError:
        return False


def validate_tools(tools):
    """Show the error box when the tools field contains text that is not valid JSON."""
    if tools.strip() == "" or is_valid_json(tools):
        return gr.update(visible=False)
    else:
        return gr.update(value="Tools must be valid JSON.", visible=True)


def json_to_markdown(json_obj):
    """Convert a list of parsed output items into a formatted markdown string."""
    markdown = ""
    for item in json_obj:
        if item.get("type") == "text":
            # Plain text segments are passed through as markdown paragraphs.
            markdown += item.get("text", "") + "\n\n"
        elif item.get("type") == "function":
            # Function calls are rendered as fenced JSON blocks.
            markdown += "```json\n"
            markdown += json.dumps(item, indent=2)
            markdown += "\n```\n\n"
    return markdown.strip()

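# Illustrative only -- the exact structure of the parsed output comes from postprocess.py's
# postprocess_output (not shown here); the shape below is assumed from the checks in bot():
#
#   json_to_markdown([
#       {"type": "text", "text": "Let me check the weather."},
#       {"type": "function", "function": {"name": "get_current_weather",
#                                         "arguments": {"location": "Boston, MA"}}},
#   ])
#
# yields the text paragraph followed by the function call in a fenced ```json block.

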
def user(user_message, history):
    # Append the new user turn with an empty assistant slot, and clear the input box.
    return "", history + [[user_message, None]]


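# bot() streams chunks from generate() into the last assistant slot. Whenever the
# accumulated text contains the model's "endtoolcall" marker, it is run through
# postprocess_output; if that returns parsed items, they are rendered with
# json_to_markdown so tool calls show up as JSON blocks instead of raw marker text.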
def bot(history, system_prompt, tools, role, max_new_tokens, temperature):
    user_message = history[-1][0]
    if history[-1][1] is None:
        history[-1][1] = ""

    for chunk in generate(user_message, history[:-1], system_prompt, tools, role, max_new_tokens, temperature):
        history[-1][1] += chunk
        print(history[-1][1])

        if "endtoolcall" in history[-1][1]:
            process_output = postprocess_output(history[-1][1])
            print("process output:\n", process_output)
            if process_output:
                temp_history = copy.deepcopy(history)
                if isinstance(process_output, list) and len(process_output) > 0 and isinstance(process_output[0], dict):
                    markdown_output = json_to_markdown(process_output)
                    temp_history[-1][1] = markdown_output
                else:
                    temp_history[-1][1] = str(process_output)
                print(temp_history[-1][1])
                print("--------------------------")
                yield temp_history
            else:
                print(history[-1][1])
                print("--------------------------")
                yield history
        else:
            print(history[-1][1])
            print("--------------------------")
            yield history


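# generate() builds an OpenAI-style message list (system prompt, prior turns, then the new
# message under the selected role), optionally reformats it with preprocess_input when a
# tools JSON schema is supplied, and streams tokens from model.generate via
# TextIteratorStreamer. The message contents below are illustrative, but the conversation
# passed to apply_chat_template has this shape:
#
#   [{"role": "system", "content": "You are a helpful assistant."},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello! How can I help?"},
#    {"role": "user", "content": "What's the weather in Boston, MA?"}]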
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    tools: str,
    role: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
) -> Iterator[str]:
    global model, tokenizer
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in chat_history:
        conversation.extend(
            [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_msg}]
        )
    conversation.append({"role": role, "content": message})

    if tools:
        if not is_valid_json(tools):
            yield "Invalid JSON in tools. Please correct it."
            return
        tools = json.loads(tools)
        formatted_msgs = preprocess_input(msgs=conversation, tools=tools)
    else:
        formatted_msgs = conversation

    input_ids = tokenizer.apply_chat_template(formatted_msgs, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=1.2,
    )
    # Run generation in a background thread so tokens can be yielded as they stream in.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for text in streamer:
        yield text


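# Greeting shown as the assistant's first message in the chat; it documents the expected
# `tools` / `functions` JSON formats for users of the demo.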
bot_message = """Hello! How can I assist you today? If you have any questions or need information on a specific topic, feel free to ask. I can also utilize `tools` that you input to help you better. For example: |
|
``` |
|
[ |
|
{ |
|
"type": "function", |
|
"function": { |
|
"name": "get_current_weather", |
|
"description": "Get the current weather in a given location", |
|
"parameters": { |
|
"type": "object", |
|
"properties": { |
|
"location": { |
|
"type": "string", |
|
"description": "Must include the city AND state, e.g. 'San Francisco, CA'" |
|
}, |
|
"unit": { |
|
"type": "string", |
|
"enum": |
|
["celsius", "fahrenheit"] |
|
} |
|
}, |
|
"required": ["location"] |
|
} |
|
} |
|
} |
|
] |
|
``` |
|
|
|
You can also define `functions` (deprecated in favor of `tools` in OpenAI): |
|
``` |
|
[ |
|
{ |
|
"name": "get_current_date", |
|
"description": "Gets the current date at the given location. Results are in ISO 8601 date format; e.g. 2024-04-25", |
|
"parameters": { |
|
"type": "object", |
|
"properties": { |
|
"location": { |
|
"type": "string", |
|
"description": "The city and state to get the current date at, e.g. San Francisco, CA" |
|
} |
|
}, |
|
"required":["location"] |
|
} |
|
} |
|
] |
|
``` |
|
""" |

def create_chat_interface():
    with gr.Blocks(css="style.css") as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row(equal_height=True, elem_id="main-row"):
            with gr.Column(scale=3, min_width=500):
                chatbot = gr.Chatbot(
                    value=[("Hi", bot_message)],
                    show_copy_button=True,
                    elem_id="chatbot",
                    show_label=False,
                    render_markdown=True,
                    height="100%",
                    layout="bubble",
                    avatar_images=("human.png", "bot.png"),
                )

                error_box = gr.Markdown(visible=False, elem_id="error-box")

            with gr.Column(scale=2, min_width=300):
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    label="Select Model",
                    # The default checkpoint is not part of model_choices, so allow it as a custom value.
                    value=model_id,
                    allow_custom_value=True,
                )
                model_dropdown.change(load_model, inputs=[model_dropdown])

                with gr.Accordion("Settings", open=False):
                    max_new_tokens = gr.Slider(
                        label="Max new tokens",
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        step=1,
                        value=DEFAULT_MAX_NEW_TOKENS,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.0,
                        maximum=1.2,
                        step=0.01,
                        value=0.01,
                    )

                with gr.Row():
                    role = gr.Dropdown(choices=["user", "observation"], value="user", label="Role", scale=4)
                    system_prompt = gr.Textbox(label="System Prompt", lines=1, info="Optional")
                    tools = gr.Textbox(label="Tools", lines=1, placeholder="Enter tools in JSON format", info="Optional")

        with gr.Row():
            user_input = gr.Textbox(
                label="User Input",
                placeholder="Type your message here...",
                show_label=True,
                scale=8,
            )
            submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-button")
            clear_btn = gr.Button("Clear Conversation", elem_id="clear-button")

        tools.change(validate_tools, tools, error_box)

        # First append the user turn (fast, non-queued), then stream the bot reply into it.
        submit_btn.click(
            user,
            [user_input, chatbot],
            [user_input, chatbot],
            queue=False,
        ).then(
            bot,
            [chatbot, system_prompt, tools, role, max_new_tokens, temperature],
            chatbot,
        )

        clear_btn.click(lambda: ([], None), outputs=[chatbot, error_box])

        gr.Markdown(LICENSE)

    return demo


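# For quick experimentation without the UI, generate() can be driven directly (a GPU and the
# preprocess/postprocess modules are still required). A minimal sketch, with an assumed
# tools_json string like the example shown in bot_message above:
#
#   for chunk in generate("What's the weather in Boston, MA?", [], "", tools_json, "user"):
#       print(chunk, end="", flush=True)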
if __name__ == "__main__":
    # Set up a local jsonrepair install and put node_modules/.bin on PATH before importing
    # the preprocess/postprocess helpers below.
    if not os.path.exists('package.json'):
        print("Initializing npm project...")
        run_command("npm init -y")

    print("Installing jsonrepair...")
    run_command("npm install jsonrepair")

    print("Verifying jsonrepair installation:")
    run_command("npm list jsonrepair")

    os.environ['PATH'] = f"{os.path.join(os.getcwd(), 'node_modules', '.bin')}:{os.environ['PATH']}"

    from preprocess import preprocess_input
    from postprocess import postprocess_output

    demo = create_chat_interface()
    demo.queue(max_size=20).launch()