Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import re | |
| from dotenv import load_dotenv | |
| from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage, | |
| ToolMessage) | |
| from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings, | |
| HuggingFaceEndpoint) | |
| from langgraph.graph import START, MessagesState, StateGraph | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from tools import (absolute, add, analyze_csv_file, analyze_excel_file, | |
| arvix_search, audio_transcription, compound_interest, | |
| convert_temperature, divide, exponential, | |
| extract_text_from_image, factorial, floor_divide, | |
| get_current_time_in_timezone, greatest_common_divisor, | |
| is_prime, least_common_multiple, logarithm, modulus, | |
| multiply, percentage_calculator, power, python_code_parser, | |
| reverse_sentence, roman_calculator_converter, square_root, | |
| subtract, web_content_extract, web_search, wiki_search) | |
# Pull secrets from a local .env file (if present) into the process
# environment, then read the API credentials this module relies on.
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

# Every tool the agent may call. The same list is bound to the LLM and handed
# to the graph's ToolNode, so the order here is the order the model sees.
tools = [
    # Arithmetic
    multiply,
    add,
    subtract,
    power,
    divide,
    modulus,
    square_root,
    floor_divide,
    absolute,
    logarithm,
    exponential,
    # Search, conversion and miscellaneous calculators
    web_search,
    roman_calculator_converter,
    get_current_time_in_timezone,
    compound_interest,
    convert_temperature,
    factorial,
    greatest_common_divisor,
    is_prime,
    least_common_multiple,
    percentage_calculator,
    # Knowledge sources and file / media analysis
    wiki_search,
    analyze_excel_file,
    arvix_search,
    audio_transcription,
    python_code_parser,
    analyze_csv_file,
    extract_text_from_image,
    reverse_sentence,
    web_content_extract,
]
# System prompt sent to the model. It asks for a terse final answer (a number,
# a few words, or a comma separated list) and for the reply to be wrapped in a
# one-element JSON list of the form
# [{"task_id": ..., "submitted_answer": ...}].
# NOTE: this text is runtime behavior, not documentation; editing it changes
# what the model is asked to produce.
system_prompt = """
You are a general AI assistant. I will ask you a question.
Report your thoughts, and finish your answer with only the answer, no extra text, no prefix, and no explanation.
Your answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use a comma to write your number, nor use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
Format your output as: [{"task_id": ..., "submitted_answer": ...}]
Do NOT include the format string or any JSON inside the submitted_answer field. Only output a single flat list as: [{"task_id": ..., "submitted_answer": ...}]
"""
# System message wrapping the prompt above; build_graph's assistant node
# prepends it to the conversation on every model call.
sys_msg = SystemMessage(content=system_prompt)
class _AgentState(MessagesState):
    """Graph state: the message history plus the id of the task being answered."""

    # Declared on the schema so the caller-supplied id is tracked by the graph
    # and visible to nodes; with plain MessagesState only "messages" is a
    # state key and the id never reached the assistant node.
    task_id: str


def build_graph():
    """Build and compile the tool-using agent graph.

    The graph has two nodes:

    * ``assistant`` - calls the tool-enabled chat model with the system prompt
      prepended to the conversation.
    * ``tools`` - executes whatever tool calls the model requested.

    Control starts at ``assistant``; ``tools_condition`` routes to ``tools``
    when the last message carries tool calls and ends the run otherwise, and
    ``tools`` always hands control back to ``assistant``.

    Returns:
        The compiled graph. Its input state takes ``messages`` and, optionally,
        ``task_id`` (defaults to ``"1"`` when absent). The final message is an
        ``AIMessage`` whose content is the JSON text
        ``[{"task_id": ..., "submitted_answer": ...}]``.
    """
    # Raw text-generation endpoint on the Hugging Face Hub.
    llm_endpoint = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
        temperature=0.1,
        max_new_tokens=1024,
        timeout=60,
    )
    # Wrap the endpoint to get chat-model behavior, then expose the tools.
    llm = ChatHuggingFace(llm=llm_endpoint)
    llm_with_tools = llm.bind_tools(tools)

    # --- Nodes ---
    def extract_answer(llm_output):
        """Return the answer carried by the first line of the model output.

        If that line is a JSON list whose first element is a dict with a
        ``submitted_answer`` key, that value is returned; otherwise the first
        line itself is returned.
        """
        # Message content is not guaranteed to be a plain string; coerce so
        # the string handling below cannot raise.
        text = llm_output if isinstance(llm_output, str) else str(llm_output)
        first_line = text.strip().split('\n')[0]
        try:
            parsed = json.loads(first_line)
        except (ValueError, RecursionError):
            # Not JSON: treat the line as the bare answer.
            return first_line
        if (
            isinstance(parsed, list)
            and parsed
            and isinstance(parsed[0], dict)
            and "submitted_answer" in parsed[0]
        ):
            return parsed[0]["submitted_answer"]
        return first_line

    def assistant(state: _AgentState):
        """Call the model and return either its tool request or the final answer."""
        conversation = [sys_msg] + state["messages"]
        llm_response = llm_with_tools.invoke(conversation)

        # A reply that requests tools must be forwarded untouched: rebuilding
        # the message from its content alone would drop ``tool_calls`` and
        # ``tools_condition`` would never route to the tools node.
        if getattr(llm_response, "tool_calls", None):
            return {"messages": [llm_response]}

        answer_text = extract_answer(llm_response.content)
        task_id = str(state.get("task_id", "1"))  # Ensure task_id is a string
        formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
        payload = json.dumps(formatted, ensure_ascii=False)
        return {"messages": [AIMessage(content=payload)]}

    # --- Graph Definition ---
    builder = StateGraph(_AgentState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
def is_valid_agent_output(output):
    """Return True when *output* has the required submission format.

    The required format, ignoring surrounding whitespace, is a single line::

        Answers (answers): [{"task_id": ..., "submitted_answer": ...}]

    The bracketed part must be valid JSON and every element of the list must
    be an object containing both ``task_id`` and ``submitted_answer``.

    Args:
        output: The text to validate. Anything that is not a ``str``
            (including ``None``) is reported as invalid rather than raising.

    Returns:
        ``True`` if the text matches the format, ``False`` otherwise.
    """
    if not isinstance(output, str):
        return False

    # Basic regex to check the overall shape before attempting to parse.
    pattern = r'^Answers \(answers\): \[(\{.*\})\]$'
    match = re.match(pattern, output.strip())
    if match is None:
        return False

    # Re-wrap the captured objects in brackets so they parse as one JSON list.
    try:
        answers_list = json.loads(f'[{match.group(1)}]')
    except (ValueError, RecursionError):
        return False

    return all(
        isinstance(ans, dict) and "task_id" in ans and "submitted_answer" in ans
        for ans in answers_list
    )
def extract_flat_answer(output):
    """Reduce *output* to a single ``Answers (answers): [{...}]`` string.

    Searches *output* for ``Answers (answers): [{...}]`` fragments and, when
    the last one found is valid JSON whose first element contains both
    ``task_id`` and ``submitted_answer``, returns just that fragment.

    Args:
        output: Text that may contain one or more answer fragments.

    Returns:
        The normalized ``Answers (answers): [...]`` string, or *output*
        unchanged when no usable fragment is found (including when *output*
        is not a string).
    """
    if not isinstance(output, str):
        return output  # fallback: nothing to search in

    pattern = r'Answers \(answers\): \[(\{.*?\})\]'
    matches = re.findall(pattern, output)
    if not matches:
        return output  # fallback

    # Only the last fragment is considered.
    candidate = matches[-1]
    try:
        answers_list = json.loads(f'[{candidate}]')
    except (ValueError, RecursionError):
        return output  # fallback: fragment is not valid JSON

    first = answers_list[0]
    if isinstance(first, dict) and "task_id" in first and "submitted_answer" in first:
        return f'Answers (answers): [{candidate}]'
    return output  # fallback
# Manual smoke test: run one question through the graph and report whether the
# agent's reply has the expected submission shape.
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"

    # Build the graph
    graph = build_graph()

    # The initial state for the graph
    messages = [HumanMessage(content=question)]
    initial_state = {"messages": messages, "task_id": "test123"}

    # Stream full state snapshots so each step can be inspected as it happens.
    for s in graph.stream(initial_state, stream_mode="values"):
        message = s["messages"][-1]

        if isinstance(message, ToolMessage):
            print("---RETRIEVED CONTEXT---")
            print(message.content)
            print("-----------------------")
            continue

        # Only a final assistant reply is a candidate answer. The first
        # snapshot ends with the user's own question, and an assistant message
        # that requests tools is an intermediate step; neither should be
        # reported as a formatting failure.
        if not isinstance(message, AIMessage) or getattr(message, "tool_calls", None):
            continue

        output = message.content
        try:
            parsed = json.loads(output)
        except (TypeError, ValueError) as e:
            print("❌ Output is NOT in the correct format!", e)
            continue

        # Require a non-empty list whose first element is a dict, so that the
        # key checks are real key lookups rather than substring tests.
        if (
            isinstance(parsed, list)
            and parsed
            and isinstance(parsed[0], dict)
            and "task_id" in parsed[0]
            and "submitted_answer" in parsed[0]
        ):
            print("✅ Output is in the correct format!")
        else:
            print("❌ Output is NOT in the correct format!")