import copy
import json
import os
import subprocess
import sys
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
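
# Helper for running shell commands; the __main__ block below uses it to install the
# jsonrepair npm package before the pre/postprocessing helpers are imported.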
def run_command(command):
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    output, error = process.communicate()
    if process.returncode != 0:
        print(f"Error executing command: {command}")
        print(f"Error message: {error.decode('utf-8')}")
        sys.exit(1)
    return output.decode('utf-8')
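
# Illustrative usage (this exact call is not made elsewhere in this file):
# run_command("npm --version") returns the command's stdout as a string, and exits
# the process with status 1 if the command fails.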

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8000"))

model_choices = [
    "rubra-ai/Meta-Llama-3-8B-Instruct",
    "rubra-ai/Qwen2-7B-Instruct",
    "rubra-ai/Phi-3-mini-128k-instruct",
    "rubra-ai/Mistral-7B-Instruct-v0.3",
    # "rubra-ai/Mistral-7B-Instruct-v0.2",
    # "rubra-ai/gemma-1.1-2b-it"
]

DESCRIPTION = """\
# Rubra v0.1 - A Collection of Tool (Function) Calling LLMs

This is a demo of the Rubra collection of models. You can use the models for general conversation,
task completion, and function calling with the provided tools input.

See more at https://docs.rubra.ai/ & https://github.com/rubra-ai/rubra
"""

model_table = """
<p/>

---

## Rubra Benchmarks

| Model | Params (in billions) | Function Calling | MMLU (5-shot) | GPQA (0-shot) | GSM-8K (8-shot, CoT) | MATH (4-shot, CoT) | MT-bench |
|------------------------------------------|----------------------|------------------|---------------|---------------|----------------------|--------------------|----------|
| GPT-4o                                   | -                    | 98.57%           | -             | 53.6          | -                    | -                  | -        |
| Claude-3.5 Sonnet                        | -                    | 98.57%           | 88.7          | 59.4          | -                    | -                  | -        |
| Rubra Llama-3 70B Instruct               | 70.6                 | 97.85%           | 75.90         | 33.93         | 82.26                | 34.24              | 8.36     |
| Rubra Llama-3 8B Instruct                | 8.9                  | 89.28%           | 64.39         | 31.70         | 68.99                | 23.76              | 8.03     |
| Rubra Qwen2-7B-Instruct                  | 8.55                 | 85.71%           | 68.88         | 30.36         | 75.82                | 28.72              | 8.08     |
| Rubra Mistral 7B Instruct v0.3           | 8.12                 | 73.57%           | 59.12         | 29.91         | 43.29                | 11.14              | 7.69     |
| Rubra Phi-3 Mini 128k Instruct           | 4.73                 | 70.00%           | 67.87         | 29.69         | 79.45                | 30.80              | 8.21     |
| Rubra Mistral 7B Instruct v0.2           | 8.11                 | 69.28%           | 58.90         | 29.91         | 34.12                | 8.36               | 7.36     |
| meetkai/functionary-small-v2.5           | 8.03                 | 57.14%           | 63.92         | 32.14         | 66.11                | 20.54              | 7.09     |
| Nexusflow/NexusRaven-V2-13B              | 13.0                 | 53.75% ∔         | 43.23         | 28.79         | 22.67                | 7.12               | 5.36     |
| Mistral Large (closed-source)            | -                    | 48.60%           | -             | -             | 91.21                | 45.0               | -        |
| Rubra Gemma-1.1 2B Instruct              | 2.84                 | 45.00%           | 38.85         | 24.55         | 6.14                 | 2.38               | 5.75     |
| meetkai/functionary-medium-v3.0          | 70.6                 | 46.43%           | 79.85         | 38.39         | 89.54                | 43.02              | 5.49     |
| gorilla-llm/gorilla-openfunctions-v2     | 6.91                 | 41.25% ∔         | 49.14         | 23.66         | 48.29                | 17.54              | 5.13     |
| NousResearch/Hermes-2-Pro-Llama-3-8B     | 8.03                 | 41.25%           | 64.16         | 31.92         | 73.92                | 21.58              | 7.83     |
| Mistral 7B Instruct v0.3                 | 7.25                 | 22.5%            | 62.10         | 30.58         | 53.07                | 12.98              | 7.50     |
| Gemma-1.1 2B Instruct                    | 2.51                 | -                | 37.84         | 22.99         | 6.29                 | 6.14               | 5.82     |
| Llama-3 8B Instruct                      | 8.03                 | -                | 65.69         | 31.47         | 77.41                | 27.58              | 8.07     |
| Llama-3 70B Instruct                     | 70.6                 | -                | 79.90         | 38.17         | 90.67                | 44.24              | 8.88     |
| Mistral 7B Instruct v0.2                 | 7.24                 | -                | 59.27         | 27.68         | 43.21                | 10.30              | 7.50     |
| Phi-3 Mini 128k Instruct                 | 3.82                 | -                | 69.36         | 27.01         | 83.7                 | 32.92              | 8.02     |
| Qwen2-7B-Instruct                        | 7.62                 | -                | 70.78         | 32.14         | 78.54                | 30.10              | 8.29     |
∔ `Nexusflow/NexusRaven-V2-13B` and `gorilla-llm/gorilla-openfunctions-v2` do not accept tool observations (the result of running a tool or function once the LLM calls it), so we appended the observation to the prompt.
"""

LICENSE = """
<p/>

---

Rubra code is licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Rubra models are licensed under the parent model's license. See the parent model card for more information.
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

model_id = "rubra-ai/Meta-Llama-3-8B-Instruct"  # Default model
model = None
tokenizer = None


def load_model(model_name):
    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=False)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Some tokenizers (e.g. Llama-3) define no pad token; fall back to EOS so generation
    # does not complain about a missing pad token id.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id


if torch.cuda.is_available():
    load_model(model_id)  # Load the default model
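# Note: load_model() is also wired to the model dropdown in the UI, so selecting a
# different model in the dropdown reloads the weights and tokenizer at runtime.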


def is_valid_json(tools: str) -> bool:
    try:
        json.loads(tools)
        return True
    except ValueError:
        return False


def validate_tools(tools):
    if tools.strip() == "" or is_valid_json(tools):
        return gr.update(visible=False)
    else:
        return gr.update(visible=True)
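
# json_to_markdown renders postprocessed model output for the chat window: plain "text"
# items are emitted as-is, and "function" (tool call) items are shown as fenced JSON.
# For example, an input like
#   [{"type": "text", "text": "Sure."}, {"type": "function", "function": {"name": "get_stock_information", ...}}]
# becomes the text "Sure." followed by a json code block for the tool call.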
def json_to_markdown(json_obj):
    """Convert a JSON object to a formatted markdown string."""
    markdown = ""
    for item in json_obj:
        if item.get("type") == "text":
            # For text items, just add the text content
            markdown += item.get("text", "") + "\n\n"
        elif item.get("type") == "function":
            # For function calls, format as JSON
            markdown += "```json\n"
            markdown += json.dumps(item, indent=2)
            markdown += "\n```\n\n"
    return markdown.strip()


def user(user_message, history):
    return "", history + [[user_message, None]]
def bot(history, system_prompt, tools, role, max_new_tokens, temperature):
    user_message = history[-1][0]
    if history[-1][1] is None:
        history[-1][1] = ""  # Ensure it's never None

    for chunk in generate(user_message, history[:-1], system_prompt, tools, role, max_new_tokens, temperature):
        history[-1][1] += chunk
        print(history[-1][1])
        if "endtoolcall" in history[-1][1]:
            process_output = postprocess_output(history[-1][1])
            print("process output:\n", process_output)
            if process_output:
                temp_history = copy.deepcopy(history)  # Deep copy so the raw text keeps accumulating
                if isinstance(process_output, list) and len(process_output) > 0 and isinstance(process_output[0], dict):
                    temp_history[-1][1] = json_to_markdown(process_output)
                else:
                    temp_history[-1][1] = str(process_output)
                print(temp_history[-1][1])
                print("--------------------------")
                yield temp_history
            else:
                print(history[-1][1])
                print("--------------------------")
                yield history
        else:
            print(history[-1][1])
            print("--------------------------")
            yield history
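
# generate() builds the conversation, applies the model's chat template, and streams
# tokens from a background generation thread. When a tools JSON string is supplied,
# the messages are first rewritten by preprocess_input() so the model sees the tool
# definitions; input longer than MAX_INPUT_TOKEN_LENGTH is trimmed from the left.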
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    tools: str,
    role: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
) -> Iterator[str]:
    global model, tokenizer

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user_turn, assistant_turn in chat_history:
        conversation.extend(
            [{"role": "user", "content": user_turn}, {"role": "assistant", "content": assistant_turn}]
        )
    conversation.append({"role": role, "content": message})

    if tools:
        if not is_valid_json(tools):
            yield "Invalid JSON in tools. Please correct it."
            return
        tools = json.loads(tools)
        formatted_msgs = preprocess_input(msgs=conversation, tools=tools)
    else:
        formatted_msgs = conversation

    input_ids = tokenizer.apply_chat_template(formatted_msgs, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=1.2,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for text in streamer:
        yield text
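
# Welcome message shown as the first assistant turn in the chatbot. It documents the
# JSON shapes accepted by the Tools textbox (OpenAI-style `tools` and legacy `functions`).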
bot_message = """Hello! How can I assist you today? If you have any questions or need information on a specific topic, feel free to ask. I can also utilize `tools` that you input to help you better. For example:
```
[
  {
    "type": "function",
    "function": {
      "name": "get_stock_information",
      "description": "Get the current stock market information for a given company",
      "parameters": {
        "type": "object",
        "properties": {
          "ticker_symbol": {
            "type": "string",
            "description": "The stock ticker symbol of the company, e.g., 'AAPL' for Apple Inc."
          },
          "exchange": {
            "type": "string",
            "description": "The stock exchange where the company is listed, e.g., 'NASDAQ'. If not provided, default to the primary exchange for the ticker symbol."
          },
          "data_type": {
            "type": "string",
            "enum": ["price", "volume", "market_cap"],
            "description": "The type of stock data to retrieve: 'price' for current price, 'volume' for trading volume, 'market_cap' for market capitalization."
          }
        },
        "required": ["ticker_symbol", "data_type"]
      }
    }
  }
]
```
You can also define `functions` (deprecated in favor of `tools` in the OpenAI API):
```
[
  {
    "name": "get_current_date",
    "description": "Gets the current date at the given location. Results are in ISO 8601 date format; e.g. 2024-04-25",
    "parameters": {
      "type": "object",
      "properties": {
        "location": {
          "type": "string",
          "description": "The city and state to get the current date at, e.g. San Francisco, CA"
        }
      },
      "required": ["location"]
    }
  }
]
```
"""
def create_chat_interface():
    with gr.Blocks(css="style.css") as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row(equal_height=True, elem_id="main-row"):
            with gr.Column(scale=3, min_width=500):
                # Initialize the chatbot with the welcome message
                chatbot = gr.Chatbot(
                    value=[("Hi", bot_message)],
                    show_copy_button=True,
                    elem_id="chatbot",
                    show_label=False,
                    render_markdown=True,
                    height="100%",
                    layout="bubble",
                    avatar_images=("human.png", "bot.png"),
                )
                error_box = gr.Markdown(visible=False, elem_id="error-box")

            with gr.Column(scale=2, min_width=300):
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    label="Select Model",
                    value="rubra-ai/Meta-Llama-3-8B-Instruct",
                )
                model_dropdown.change(load_model, inputs=[model_dropdown])

                with gr.Accordion("Settings", open=False):
                    max_new_tokens = gr.Slider(
                        label="Max new tokens",
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        step=1,
                        value=DEFAULT_MAX_NEW_TOKENS,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.0,
                        maximum=1.2,
                        step=0.01,
                        value=0.01,
                    )

                with gr.Row():
                    role = gr.Dropdown(choices=["user", "observation"], value="user", label="Role", scale=4)
                    system_prompt = gr.Textbox(label="System Prompt", lines=1, info="Optional")
                    tools = gr.Textbox(label="Tools", lines=1, placeholder="Enter tools in JSON format", info="Optional")

        with gr.Row():
            user_input = gr.Textbox(
                label="User Input",
                placeholder="Type your message here...",
                show_label=True,
                scale=8,
            )
            submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-button")
            clear_btn = gr.Button("Clear Conversation", elem_id="clear-button")

        tools.change(validate_tools, tools, error_box)

        submit_btn.click(
            user,
            [user_input, chatbot],
            [user_input, chatbot],
            queue=False,
        ).then(
            bot,
            [chatbot, system_prompt, tools, role, max_new_tokens, temperature],
            chatbot,
        )

        clear_btn.click(lambda: ([], None), outputs=[chatbot, error_box])

        gr.Markdown(model_table)
        gr.Markdown(LICENSE)

    return demo
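
# Startup: the pre/postprocessing helpers are imported only after the jsonrepair npm
# package is installed and node_modules/.bin is on PATH, presumably because they shell
# out to the jsonrepair CLI to repair malformed tool-call JSON emitted by the models.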
if __name__ == "__main__":
    # Initialize npm project if package.json doesn't exist
    if not os.path.exists('package.json'):
        print("Initializing npm project...")
        run_command("npm init -y")

    # Install jsonrepair locally
    print("Installing jsonrepair...")
    run_command("npm install jsonrepair")

    # Verify installation
    print("Verifying jsonrepair installation:")
    run_command("npm list jsonrepair")

    # Add node_modules/.bin to PATH
    os.environ['PATH'] = f"{os.path.join(os.getcwd(), 'node_modules', '.bin')}:{os.environ['PATH']}"

    from preprocess import preprocess_input
    from postprocess import postprocess_output

    demo = create_chat_interface()
    demo.queue(max_size=20).launch()