import copy
import json
import os
import subprocess
import sys
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
def run_command(command):
    """Run a shell command, returning its stdout; print the error and exit on failure."""
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    output, error = process.communicate()
    if process.returncode != 0:
        print(f"Error executing command: {command}")
        print(f"Error message: {error.decode('utf-8')}")
        sys.exit(1)
    return output.decode('utf-8')
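# Illustrative usage (assumes a Node.js runtime is available in the Space):
#   node_version = run_command("node --version")  # returns the command's stdout as a string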
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8000"))

model_choices = [
    "rubra-ai/Meta-Llama-3-8B-Instruct",
    "rubra-ai/Qwen2-7B-Instruct",
    "rubra-ai/Phi-3-mini-128k-instruct",
    "rubra-ai/Mistral-7B-Instruct-v0.3",
    # "rubra-ai/Mistral-7B-Instruct-v0.2",
    # "rubra-ai/gemma-1.1-2b-it"
]
DESCRIPTION = """\ | |
# Rubra v0.1 - A Collection of Tool (Function) Calling LLMs | |
This is a demo of the Rubra collection of models. You can use the models for general conversation, | |
task completion, and function calling with the provided tools input. | |
See more at https://docs.rubra.ai/ & https://github.com/rubra-ai/rubra | |
""" | |
model_table = """ | |
<p/> | |
--- | |
## Rubra Benchmarks | |
| Model | Params (in billions) | Function Calling | MMLU (5-shot) | GPQA (0-shot) | GSM-8K (8-shot, CoT) | MATH (4-shot, CoT) | MT-bench | | |
|------------------------------------------|----------------------|------------------|---------------|---------------|----------------------|--------------------|----------| | |
| GPT-4o | - | 98.57% | - | 53.6 | - | - | - | | |
| Claude-3.5 Sonnet | - | 98.57% | 88.7 | 59.4 | - | - | - | | |
| Rubra Llama-3 70B Instruct | 70.6 | 97.85% | 75.90 | 33.93 | 82.26 | 34.24 | 8.36 | | |
| Rubra Llama-3 8B Instruct | 8.9 | 89.28% | 64.39 | 31.70 | 68.99 | 23.76 | 8.03 | | |
| Rubra Qwen2-7B-Instruct | 8.55 | 85.71% | 68.88 | 30.36 | 75.82 | 28.72 | 8.08 | | |
| Rubra Mistral 7B Instruct v0.3 | 8.12 | 73.57% | 59.12 | 29.91 | 43.29 | 11.14 | 7.69 | | |
| Rubra Phi-3 Mini 128k Instruct | 4.73 | 70.00% | 67.87 | 29.69 | 79.45 | 30.80 | 8.21 | | |
| Rubra Mistral 7B Instruct v0.2 | 8.11 | 69.28% | 58.90 | 29.91 | 34.12 | 8.36 | 7.36 | | |
| meetkai/functionary-small-v2.5 | 8.03 | 57.14% | 63.92 | 32.14 | 66.11 | 20.54 | 7.09 | | |
| Nexusflow/NexusRaven-V2-13B | 13.0 | 53.75% ∔ | 43.23 | 28.79 | 22.67 | 7.12 | 5.36 | | |
| Mistral Large (closed-source) | - | 48.60% | - | - | 91.21 | 45.0 | - | | |
| Rubra Gemma-1.1 2B Instruct | 2.84 | 45.00% | 38.85 | 24.55 | 6.14 | 2.38 | 5.75 | | |
| meetkai/functionary-medium-v3.0 | 70.6 | 46.43% | 79.85 | 38.39 | 89.54 | 43.02 | 5.49 | | |
| gorilla-llm/gorilla-openfunctions-v2 | 6.91 | 41.25% ∔ | 49.14 | 23.66 | 48.29 | 17.54 | 5.13 | | |
| NousResearch/Hermes-2-Pro-Llama-3-8B | 8.03 | 41.25% | 64.16 | 31.92 | 73.92 | 21.58 | 7.83 | | |
| Mistral 7B Instruct v0.3 | 7.25 | 22.5% | 62.10 | 30.58 | 53.07 | 12.98 | 7.50 | | |
| Gemma-1.1 2B Instruct | 2.51 | - | 37.84 | 22.99 | 6.29 | 6.14 | 5.82 | | |
| Llama-3 8B Instruct | 8.03 | - | 65.69 | 31.47 | 77.41 | 27.58 | 8.07 | | |
| Llama-3 70B Instruct | 70.6 | - | 79.90 | 38.17 | 90.67 | 44.24 | 8.88 | | |
| Mistral 7B Instruct v0.2 | 7.24 | - | 59.27 | 27.68 | 43.21 | 10.30 | 7.50 | | |
| Phi-3 Mini 128k Instruct | 3.82 | - | 69.36 | 27.01 | 83.7 | 32.92 | 8.02 | | |
| Qwen2-7B-Instruct | 7.62 | - | 70.78 | 32.14 | 78.54 | 30.10 | 8.29 | | |
∔ `Nexusflow/NexusRaven-V2-13B` and `gorilla-llm/gorilla-openfunctions-v2` don't accept tool observations (the result of running a tool or function once the LLM calls it), so we appended the observation to the prompt instead.
"""
LICENSE = """ | |
<p/> | |
--- | |
Rubra code is licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
Rubra models are licensed under the parent model's license. See the parent model card for more information. | |
""" | |
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

model_id = "rubra-ai/Meta-Llama-3-8B-Instruct"  # Default model
model = None
tokenizer = None


def load_model(model_name):
    """Load the selected checkpoint and tokenizer into the module-level globals."""
    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=False)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.generation_config.pad_token_id = tokenizer.pad_token_id


if torch.cuda.is_available():
    load_model(model_id)  # Load the default model
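# Note: load_model swaps the module-level `model`/`tokenizer` in place, which is what lets the
# model dropdown defined in create_chat_interface hot-swap checkpoints without rebuilding the UI.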
def is_valid_json(tools: str) -> bool:
    """Return True if the string parses as JSON."""
    try:
        json.loads(tools)
        return True
    except ValueError:
        return False


def validate_tools(tools):
    """Hide the error box when the Tools field is empty or valid JSON; surface it otherwise."""
    if tools.strip() == "" or is_valid_json(tools):
        return gr.update(visible=False)
    else:
        return gr.update(visible=True, value="⚠️ Tools must be valid JSON.")
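# Illustrative behavior: validate_tools('[{"type": "function"}]') hides the error box, while
# validate_tools('{not json') surfaces it; it is wired to `tools.change` in create_chat_interface.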
def json_to_markdown(json_obj):
    """Convert a list of text/function-call items into a formatted markdown string."""
    markdown = ""
    for item in json_obj:
        if item.get("type") == "text":
            # For text items, just add the text content
            markdown += item.get("text", "") + "\n\n"
        elif item.get("type") == "function":
            # For function calls, format as JSON
            markdown += "```json\n"
            markdown += json.dumps(item, indent=2)
            markdown += "\n```\n\n"
    return markdown.strip()
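# Illustrative rendering: [{"type": "text", "text": "Sure."}, {"type": "function", "function": {...}}]
# becomes "Sure." followed by the function-call item pretty-printed inside a ```json fence.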
def user(user_message, history):
    """Append the user's message to the chat history and clear the input box."""
    return "", history + [[user_message, None]]
def bot(history, system_prompt, tools, role, max_new_tokens, temperature):
    """Stream the model's reply, post-processing any completed tool calls for display."""
    user_message = history[-1][0]
    if history[-1][1] is None:
        history[-1][1] = ""  # Ensure it's never None

    ui_history = list(history)  # Clone the history for UI updates
    all_tool_outputs = []  # Store all processed outputs for final aggregation
    output_accumulated = ""  # To accumulate outputs before processing

    for chunk in generate(user_message, history[:-1], system_prompt, tools, role, max_new_tokens, temperature):
        history[-1][1] += chunk
        print(history[-1][1])
        if "endtoolcall" in history[-1][1]:
            # The raw stream contains a completed tool call; parse it so it can be shown as JSON.
            process_output = postprocess_output(history[-1][1])
            print("process output:\n", process_output)
            if process_output:
                temp_history = copy.deepcopy(history)  # Use deepcopy here
                if isinstance(process_output, list) and len(process_output) > 0 and isinstance(process_output[0], dict):
                    markdown_output = json_to_markdown(process_output)
                    temp_history[-1][1] = markdown_output
                else:
                    temp_history[-1][1] = str(process_output)
                print(temp_history[-1][1])
                print("--------------------------")
                yield temp_history
            else:
                print(history[-1][1])
                print("--------------------------")
                yield history
        else:
            print(history[-1][1])
            print("--------------------------")
            yield history
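# Because `bot` is a generator, each `yield` streams an updated chat history back to the
# gr.Chatbot component, so partial output (and parsed tool calls) appear as they are generated.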
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    tools: str,
    role: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
) -> Iterator[str]:
    """Build the conversation, apply the chat template, and stream generated text."""
    global model, tokenizer

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in chat_history:
        conversation.extend(
            [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_msg}]
        )
    conversation.append({"role": role, "content": message})

    if tools:
        if not is_valid_json(tools):
            yield "Invalid JSON in tools. Please correct it."
            return
        tools = json.loads(tools)
        formatted_msgs = preprocess_input(msgs=conversation, tools=tools)
    else:
        formatted_msgs = conversation

    input_ids = tokenizer.apply_chat_template(formatted_msgs, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Generation runs in a background thread while the TextIteratorStreamer yields
    # decoded text pieces as they are produced.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=1.2,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for text in streamer:
        yield text
bot_message = """Hello! How can I assist you today? If you have any questions or need information on a specific topic, feel free to ask. I can also utilize `tools` that you input to help you better. For example: | |
``` | |
[ | |
{ | |
"type": "function", | |
"function": { | |
"name": "get_stock_information", | |
"description": "Get the current stock market information for a given company", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"ticker_symbol": { | |
"type": "string", | |
"description": "The stock ticker symbol of the company, e.g., 'AAPL' for Apple Inc." | |
}, | |
"exchange": { | |
"type": "string", | |
"description": "The stock exchange where the company is listed, e.g., 'NASDAQ'. If not provided, default to the primary exchange for the ticker symbol." | |
}, | |
"data_type": { | |
"type": "string", | |
"enum": ["price", "volume", "market_cap"], | |
"description": "The type of stock data to retrieve: 'price' for current price, 'volume' for trading volume, 'market_cap' for market capitalization." | |
} | |
}, | |
"required": ["ticker_symbol", "data_type"] | |
} | |
} | |
} | |
] | |
``` | |
You can also define `functions` (deprecated in favor of `tools` in OpenAI): | |
``` | |
[ | |
{ | |
"name": "get_current_date", | |
"description": "Gets the current date at the given location. Results are in ISO 8601 date format; e.g. 2024-04-25", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"location": { | |
"type": "string", | |
"description": "The city and state to get the current date at, e.g. San Francisco, CA" | |
} | |
}, | |
"required":["location"] | |
} | |
} | |
] | |
``` | |
""" | |
def create_chat_interface():
    """Build the Gradio Blocks UI for the demo."""
    with gr.Blocks(css="style.css") as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row(equal_height=True, elem_id="main-row"):
            with gr.Column(scale=3, min_width=500):
                # Initialize the chatbot with the welcome message
                chatbot = gr.Chatbot(
                    value=[("Hi", bot_message)],
                    show_copy_button=True,
                    elem_id="chatbot",
                    show_label=False,
                    render_markdown=True,
                    height="100%",
                    layout='bubble',
                    avatar_images=("human.png", "bot.png")
                )
                error_box = gr.Markdown(visible=False, elem_id="error-box")

            with gr.Column(scale=2, min_width=300):
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    label="Select Model",
                    value="rubra-ai/Meta-Llama-3-8B-Instruct"
                )
                model_dropdown.change(load_model, inputs=[model_dropdown])

                with gr.Accordion("Settings", open=False):
                    max_new_tokens = gr.Slider(
                        label="Max new tokens",
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        step=1,
                        value=DEFAULT_MAX_NEW_TOKENS,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.0,
                        maximum=1.2,
                        step=0.01,
                        value=0.01,
                    )

                with gr.Row():
                    role = gr.Dropdown(choices=["user", "observation"], value="user", label="Role", scale=4)
                    system_prompt = gr.Textbox(label="System Prompt", lines=1, info="Optional")
                    tools = gr.Textbox(label="Tools", lines=1, placeholder="Enter tools in JSON format", info="Optional")

        with gr.Row():
            user_input = gr.Textbox(
                label="User Input",
                placeholder="Type your message here...",
                show_label=True,
                scale=8
            )
            submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-button")
            clear_btn = gr.Button("Clear Conversation", elem_id="clear-button")

        tools.change(validate_tools, tools, error_box)

        submit_btn.click(
            user,
            [user_input, chatbot],
            [user_input, chatbot],
            queue=False
        ).then(
            bot,
            [chatbot, system_prompt, tools, role, max_new_tokens, temperature],
            chatbot
        )

        clear_btn.click(lambda: ([], None), outputs=[chatbot, error_box])

        gr.Markdown(model_table)
        gr.Markdown(LICENSE)

    return demo
if __name__ == "__main__": | |
# Initialize npm project if package.json doesn't exist | |
if not os.path.exists('package.json'): | |
print("Initializing npm project...") | |
run_command("npm init -y") | |
# Install jsonrepair locally | |
print("Installing jsonrepair...") | |
run_command("npm install jsonrepair") | |
# Verify installation | |
print("Verifying jsonrepair installation:") | |
run_command("npm list jsonrepair") | |
# Add node_modules/.bin to PATH | |
os.environ['PATH'] = f"{os.path.join(os.getcwd(), 'node_modules', '.bin')}:{os.environ['PATH']}" | |
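    # Note: the pre/postprocess helpers are imported only after the PATH tweak above; the
    # assumption is that they shell out to the `jsonrepair` CLI installed into node_modules/.bin.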
    from preprocess import preprocess_input
    from postprocess import postprocess_output

    demo = create_chat_interface()
    demo.queue(max_size=20).launch()