# Hugging Face Spaces status banner ("Spaces: Sleeping") captured during export — not part of app.py.
| # app.py | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| from neo4j import GraphDatabase | |
| from openai import OpenAI | |
| import subprocess | |
| import time | |
| import requests | |
| import zipfile | |
| import gradio as gr | |
| # Import dotenv to load environment variables from a .env file | |
| from dotenv import load_dotenv | |
| # Load environment variables from a .env file | |
| load_dotenv() | |
# -----------------------------
# Configuration
# -----------------------------
# All settings come from environment variables (loaded above via load_dotenv()).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
# When unset this is None, and the Neo4j driver falls back to the server's
# default database.
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE")
# FIX: the previous default was " gpt-o1" — a non-existent model name with a
# leading space — while the comment promised gpt-4o-mini. Use a valid default.
CHAT_MODEL = os.getenv("CHAT_MODEL", "gpt-4o-mini")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")  # Default if not set
# -----------------------------
# Logging setup
# -----------------------------
# Dedicated application logger writing timestamped lines to stdout.
logger = logging.getLogger("proc-assistant")

fmt = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(fmt)

# Remove any handlers left over from a previous run (e.g. notebook
# re-execution) so each message is emitted exactly once.
if logger.hasHandlers():
    logger.handlers.clear()

logger.addHandler(handler)
logger.setLevel(logging.INFO)
# -----------------------------
# SYSTEM_PROMPT
# -----------------------------
# Instructions sent as the 'system' message on every assistant_reply() call.
# NOTE(review): the prompt lists only get_nodes/get_index/submit_cypher, but
# four tools (including get_relations) are registered in TOOLS — confirm
# whether get_relations should be mentioned here as well.
SYSTEM_PROMPT = """
You are a Signavio Repository assistant and trying to answer questions based on a repository
which can be accessed via cypher queries.
You have access to the following tools:
- get_nodes()
- get_index()
- submit_cypher()
1. Always first call the functions - get_nodes() get_index()
so that you have a complete list of nodes and full text index.
2. Answer questions by creating cypher queries you submit over the tool submit_cypher()
3. Always query the name and Description of a node.
4. Where an index for the node exists, create a query using a fulltext search
5. Always request related objects using a generic query 'match (n)-[]-(m)' to retrieve all nodes 'm' related to the node 'n'
6. If no responses are found (empty results), retry with a broader or parent term (e.g., "procurement" if "indirect
procurement" fails).
7. Use multiple function calls if several processes are needed.
8. Use only the returned results to answer and always cite the nodes names and Type.
9. If results are poor, ask the user to clarify, especially to clarify which node
supports answering the question best.
10. Inform the user the exact cypher query you submit.
"""
# -----------------------------
# Tool implementation / get_nodes
# -----------------------------
def get_nodes() -> str:
    """Return every distinct node label in the database as a JSON array string.

    On failure, returns a JSON object of the form {"error": "..."} instead of
    raising, so the caller (the tool-call loop) can pass it back to the model.
    """
    logger.info("Getting all node labels.")
    query = "MATCH (n) RETURN DISTINCT labels(n) AS labels"
    try:
        with driver.session(database=NEO4J_DATABASE) as session:
            labels = []
            # Each record carries a list of labels; flatten them into one list.
            for record in session.run(query):
                labels.extend(record["labels"])
        logger.info("Retrieved %d distinct node labels.", len(labels))
        return json.dumps(labels, ensure_ascii=False)
    except Exception as exc:
        logger.error("Error getting node labels: %s", exc)
        return json.dumps({"error": str(exc)}, ensure_ascii=False)
# -----------------------------
# Tool implementation / get_relations
# -----------------------------
def get_relations() -> str:
    """Return every distinct relationship type in the database as a JSON array string.

    On failure, returns a JSON object of the form {"error": "..."} instead of
    raising, so the caller (the tool-call loop) can pass it back to the model.
    """
    logger.info("Getting all relationship types.")
    query = "MATCH ()-[r]-() RETURN DISTINCT type(r) AS type"
    try:
        with driver.session(database=NEO4J_DATABASE) as session:
            rel_types = []
            for record in session.run(query):
                rel_types.append(record["type"])
        logger.info("Retrieved %d distinct relationship types.", len(rel_types))
        return json.dumps(rel_types, ensure_ascii=False)
    except Exception as exc:
        logger.error("Error getting relationship types: %s", exc)
        return json.dumps({"error": str(exc)}, ensure_ascii=False)
# -----------------------------
# Tool implementation / get_index
# -----------------------------
def get_index() -> str:
    """Return the names of all fulltext indexes in the database as a JSON array string.

    On failure, returns a JSON object of the form {"error": "..."} instead of
    raising, so the caller (the tool-call loop) can pass it back to the model.
    """
    logger.info("Getting all fulltext index names.")
    query = "SHOW FULLTEXT INDEXES YIELD name"
    try:
        with driver.session(database=NEO4J_DATABASE) as session:
            index_names = [record["name"] for record in session.run(query)]
        logger.info("Retrieved %d fulltext index names.", len(index_names))
        return json.dumps(index_names, ensure_ascii=False)
    except Exception as exc:
        logger.error("Error getting fulltext index names: %s", exc)
        return json.dumps({"error": str(exc)}, ensure_ascii=False)
# -----------------------------
# Tool implementation / submit_cypher
# -----------------------------
def submit_cypher(cypher_query: str) -> str:
    """Execute an arbitrary Cypher query and return the result rows as a JSON string.

    On failure, returns a JSON object of the form {"error": "..."} instead of
    raising, so the caller (the tool-call loop) can pass it back to the model.
    """
    logger.info("Executing Cypher query: %s", cypher_query)
    try:
        with driver.session(database=NEO4J_DATABASE) as session:
            rows = session.run(cypher_query).data()
        logger.info("Cypher query returned %d rows.", len(rows))
        return json.dumps(rows, ensure_ascii=False)
    except Exception as exc:
        logger.error("Error executing Cypher query: %s", exc)
        return json.dumps({"error": str(exc)}, ensure_ascii=False)
# -----------------------------
# Tools definition
# -----------------------------
def _tool_spec(name: str, description: str, properties: dict, required: list) -> dict:
    """Build one OpenAI function-tool schema entry (strict object parameters)."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
                "additionalProperties": False,
            },
        },
    }


# Tool schemas advertised to the model; names match the Python implementations.
TOOLS = [
    _tool_spec(
        "get_nodes",
        "Returns a list of all node labels in the database.",
        {},
        [],
    ),
    _tool_spec(
        "get_relations",
        "Returns a list of all relationship types in the database.",
        {},
        [],
    ),
    _tool_spec(
        "get_index",
        "Returns a list of all fulltext index names in the database.",
        {},
        [],
    ),
    _tool_spec(
        "submit_cypher",
        "Executes a Cypher query against the Neo4j database and returns the results as a JSON string.",
        {
            "cypher_query": {
                "type": "string",
                "description": "The Cypher query to execute.",
            }
        },
        ["cypher_query"],
    ),
]
# -----------------------------
# Chatbot assistant function
# -----------------------------
def _execute_tool_call(tool_call) -> dict:
    """Run one model-requested tool call and package the result as a 'tool' message.

    Never raises: argument-decoding errors, unknown tool names, and tool
    execution failures are all reported back to the model as error text in the
    message content, so the conversation can continue.

    Args:
        tool_call: One entry from response_message.tool_calls (OpenAI SDK object
            with .id and .function.name / .function.arguments).

    Returns:
        A dict with role='tool', the originating tool_call_id and function
        name, and a string content (required shape for tool result messages).
    """
    available_functions = {
        "get_nodes": get_nodes,
        "get_relations": get_relations,
        "get_index": get_index,
        "submit_cypher": submit_cypher,
    }
    function_name = tool_call.function.name
    function_to_call = available_functions.get(function_name)
    if function_to_call is None:
        # The model hallucinated a tool name; answer with an error instead of crashing.
        logger.warning("Function '%s' called by model not found in available_functions", function_name)
        content = f"Error: Tool '{function_name}' not found."
    else:
        try:
            function_args = json.loads(tool_call.function.arguments)
            logger.info("Calling tool function: %s with args: %r", function_name, function_args)
            if function_name == "submit_cypher":
                content = function_to_call(function_args.get("cypher_query"))
            else:
                # get_nodes / get_relations / get_index take no arguments.
                content = function_to_call()
            logger.info("Tool function response content (first 500 chars): %s", content[:500])
        except json.JSONDecodeError:
            logger.error("Error decoding function arguments JSON: %s", tool_call.function.arguments)
            content = f"Error: Invalid JSON arguments for tool '{function_name}'."
        except Exception as e:
            logger.exception("Error executing tool '%s': %s", function_name, e)
            content = f"Error executing tool '{function_name}': {e}"
    return {
        "tool_call_id": tool_call.id,  # Required for tool response messages
        "role": "tool",
        "content": content,  # Content must be a string
        "name": function_name,  # Required for tool response messages
    }


def assistant_reply(client: OpenAI, user_query: str, history: list[dict] | None = None) -> str:
    """Answer a user query via the OpenAI Chat API, running Neo4j tools as needed.

    Performs one chat-completion call with the tool schemas attached; if the
    model requests tool calls, executes them all, then makes a second call to
    obtain the final natural-language answer.

    Args:
        client: The initialized OpenAI client object.
        user_query: The user's question or request.
        history: Previous conversation messages as role/content dicts
            ([{"role": "user", ...}, {"role": "assistant", ...}, ...]).
            Defaults to an empty conversation.

    Returns:
        The assistant's reply as a string (an error description on failure).
    """
    # FIX: the original signature used a mutable default argument (history=[]),
    # a classic Python pitfall. None is the safe sentinel; behavior for all
    # existing call sites is unchanged. Copying also shields the caller's list.
    history = list(history) if history else []
    logger.info("assistant_reply received user_query: %r", user_query)
    # Log history carefully to avoid excessive output with large histories.
    if len(history) < 20:
        logger.info("assistant_reply received history: %r", history)
    else:
        logger.info("assistant_reply received history length: %d", len(history))
        logger.info("assistant_reply first 10 history items: %r", history[:10])
        logger.info("assistant_reply last 10 history items: %r", history[-10:])

    # System prompt + prior turns + the new user message.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(history)
    messages.append({"role": "user", "content": user_query})

    try:
        # First API call: the model either answers directly or requests tools.
        logger.info("Calling OpenAI chat completion (first call) with model: %s", CHAT_MODEL)
        response = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=messages,
            tools=TOOLS,
            tool_choice="auto",
        )
        response_message = response.choices[0].message
        logger.info("OpenAI response (first call) message object: %s", response_message)

        if not response_message.tool_calls:
            # No tools requested: the first response already holds the answer.
            logger.info("No tool calls detected. Returning initial response content.")
            return response_message.content if response_message.content is not None else ""

        logger.info("Tool calls detected: %s", response_message.tool_calls)
        # The assistant message carrying tool_calls must precede the tool outputs.
        messages.append(response_message)
        for tool_call in response_message.tool_calls:
            messages.append(_execute_tool_call(tool_call))

        # Second API call: let the model compose the final answer from tool output.
        logger.info("Calling OpenAI chat completion (second call after tools) with model: %s", CHAT_MODEL)
        second_response = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=messages,
        )
        final_response_message = second_response.choices[0].message
        logger.info("OpenAI response (second call) message object: %s", final_response_message)
        return final_response_message.content if final_response_message.content is not None else ""
    except Exception as e:
        logger.exception("An error occurred during OpenAI chat completion:")
        # Surface the error text to the UI rather than crashing the callback.
        return f"An error occurred while processing your request: {e}"
# -----------------------------
# Clients
# -----------------------------
# Module-level handles used by the tool functions and the Gradio callback.
# FIX: pre-bind both names to None so that a failed initialization leaves
# well-defined values (previously a failure left them undefined, producing a
# confusing NameError on first use instead of a clear connection error).
client = None
driver = None
try:
    client = OpenAI(api_key=OPENAI_API_KEY)
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    logger.info("Clients (OpenAI and Neo4j) initialized successfully.")
except Exception as e:
    logger.error("Error initializing clients: %s", e)
    # Startup continues so the Gradio UI can still come up; requests will then
    # fail with an explicit error instead of an import-time crash.
# -----------------------------
# Gradio Interface
# -----------------------------
# Re-fetch the "proc-assistant" logger for the interface code below.
# logging.getLogger returns the same singleton configured above, so this
# re-binding is effectively a no-op kept for readability.
logger = logging.getLogger("proc-assistant")
def chatbot_interface(user_query: str, history: list[dict]) -> tuple[list[dict], list[dict]]:
    """Gradio callback: forward the query to assistant_reply and grow the history.

    History arrives and leaves as OpenAI-style role/content dicts because the
    gr.Chatbot component is configured with type='messages'. The updated
    history is returned twice: once for the Chatbot display and once for the
    gr.State component.
    """
    logger.info(f"chatbot_interface received history (dict format): {history}")
    logger.info(f"chatbot_interface received user_query: {user_query}")
    try:
        # assistant_reply expects the history in the same list[dict] format.
        assistant_response_content = assistant_reply(client, user_query, history)
        logger.info(f"chatbot_interface received assistant_response_content: {assistant_response_content}")
    except Exception:
        logger.exception("Error calling assistant_reply from chatbot_interface:")
        # Fallback text shown to the user when the assistant call itself fails.
        assistant_response_content = "An error occurred while getting the assistant's response. Please check the logs for details."
    new_turn = [
        {"role": "user", "content": user_query},
        {"role": "assistant", "content": assistant_response_content},
    ]
    updated_history = history + new_turn
    logger.info(f"chatbot_interface updated history (dict format): {updated_history}")
    return updated_history, updated_history
# Create the Gradio interface with history components
# Brand theme: yellow primary accents on a white/neutral palette.
my_theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="gray",
    neutral_hue="zinc",
    radius_size="sm",
    spacing_size="md"
).set(
    # Explicit hex overrides for backgrounds, borders and buttons.
    body_background_fill="#FFFFFF",
    block_background_fill="#FFFFFF",
    block_border_color="#E5E5E5",
    input_background_fill="#FAFAFA",
    input_border_color="#DDDDDD",
    button_primary_background_fill="#F2C200",  # brand yellow
    button_primary_text_color="#000000",
    button_secondary_background_fill="#FFFFFF",
    button_secondary_text_color="#111111"
)
# Assemble the UI: header, chat display, input box, clear button, and state.
with gr.Blocks(theme=my_theme) as iface:
    # Static header: company logo linking to bpexperts.de plus the app title.
    gr.HTML("""
<div style="
display:flex;
align-items:center;
gap:20px;
padding:15px 0;
">
<a href="https://www.bpexperts.de" target="_blank" style="text-decoration:none;">
<img src='https://images.squarespace-cdn.com/content/v1/62835c73f824d0627cfba7a7/093df9f4-89e4-48f9-8359-2096efe5b65a/Logo_bp_experts_2019.png'
alt="bpExperts Logo"
style="height:70px; object-fit:contain;">
</a>
<h1 style="
color:#ffffff;
margin:0;
font-weight:600;
font-size:32px;
">
Business Flows Assistant
</h1>
</div>
""")
    # Use gr.Chatbot with type='messages' to handle history in OpenAI message format
    chatbot = gr.Chatbot(label="Process Assistant Chatbot", type='messages')
    msg = gr.Textbox(label="Ask a question about business processes.")
    clear = gr.ClearButton([msg, chatbot])
    # gr.State keeps the conversation history (list of role/content dicts)
    # between interactions; initialized empty.
    # NOTE(review): the ClearButton clears the textbox and the chat display but
    # not this state, so the assistant still sees the old history after a
    # "clear" — confirm whether that is intended.
    state = gr.State([])
    # On submit, chatbot_interface receives (msg, state) and returns the
    # updated history for both the Chatbot display and the state.
    # NOTE(review): the textbox is not reset after submit — confirm if intended.
    msg.submit(chatbot_interface, inputs=[msg, state], outputs=[chatbot, state])
# -----------------------------
# Launch the Gradio interface
# -----------------------------
# This part will run when the app.py file is executed directly.
if __name__ == "__main__":
    logger.info("Launching Gradio interface.")
    # server_name="0.0.0.0" makes the app reachable from outside a container;
    # share=False avoids creating a public Gradio share link (production setting).
    iface.launch(server_name="0.0.0.0", share=False)
    # NOTE(review): launch() blocks until the server stops, so this message is
    # only logged at shutdown — the wording is misleading.
    logger.info("Gradio interface launched.")