wt002 commited on
Commit
e0e02ab
·
verified ·
1 Parent(s): d977a88

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +7 -42
agent.py CHANGED
@@ -13,7 +13,7 @@ from langchain_community.document_loaders import WikipediaLoader
13
  from langchain_community.utilities import WikipediaAPIWrapper
14
  from langchain_community.document_loaders import ArxivLoader
15
  from langchain_core.messages import SystemMessage, HumanMessage
16
- from langchain_core.tools import tool
17
  from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
  from sentence_transformers import SentenceTransformer
@@ -55,6 +55,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
55
  from huggingface_hub import login
56
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
57
  from langchain_huggingface import HuggingFaceEndpoint
 
 
58
 
59
  load_dotenv()
60
 
@@ -643,43 +645,7 @@ def process_all_tasks(tasks: list):
643
  return results
644
 
645
 
646
- def process_question(question: str):
647
- tasks = planner(question)
648
- print(f"Tasks to perform: {tasks}")
649
-
650
- task_type = task_classifier(question)
651
- print(f"Task type: {task_type}")
652
-
653
- state = {"question": question, "last_response": "", "messages": [HumanMessage(content=question)]}
654
- next_task = decide_task(state)
655
- print(f"Next task: {next_task}")
656
-
657
- if node_skipper(state):
658
- print(f"Skipping task: {next_task}")
659
- return "Task skipped."
660
 
661
- try:
662
- if task_type == "wiki_search":
663
- response = wiki_tool.run(question)
664
- elif task_type == "math":
665
- # You should dynamically parse these inputs in real use
666
- response = calc_tool.run(question)
667
- elif task_type == "retriever":
668
- retrieval_result = retriever(state)
669
- response = retrieval_result["messages"][-1].content
670
- else:
671
- response = "Default fallback answer."
672
-
673
- return generate_final_answer(state, {task_type: response})
674
-
675
- except Exception as e:
676
- print(f"❌ Error: {e}")
677
- return "Sorry, I encountered an error processing your request."
678
-
679
-
680
- # ----------------------------------------------------------------
681
- # Process Function (Main Agent Runner) OLD Code
682
- # ----------------------------------------------------------------
683
  def process_question(question: str):
684
  tasks = planner(question)
685
  print(f"Tasks to perform: {tasks}")
@@ -712,7 +678,7 @@ def process_question(question: str):
712
  except Exception as e:
713
  print(f"❌ Error: {e}")
714
  return "Sorry, I encountered an error processing your request."
715
-
716
 
717
 
718
 
@@ -728,7 +694,7 @@ model_config = {
728
  }
729
 
730
  def build_graph(provider, model_config):
731
- from langgraph.prebuilt.tool_node import ToolNode
732
 
733
  def get_llm(provider: str, config: dict):
734
  if provider == "huggingface":
@@ -745,11 +711,10 @@ def build_graph(provider, model_config):
745
 
746
 
747
  llm = get_llm(provider, model_config)
 
 
748
  llm_with_tools = llm.bind_tools(tools)
749
 
750
- # Continue building graph logic here...
751
- # builder = StateGraph(...)
752
- # return builder.compile()
753
 
754
 
755
  sys_msg = SystemMessage(content="You are a helpful assistant.")
 
13
  from langchain_community.utilities import WikipediaAPIWrapper
14
  from langchain_community.document_loaders import ArxivLoader
15
  from langchain_core.messages import SystemMessage, HumanMessage
16
+ #from langchain_core.tools import tool
17
  from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
  from sentence_transformers import SentenceTransformer
 
55
  from huggingface_hub import login
56
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
57
  from langchain_huggingface import HuggingFaceEndpoint
58
+ from langchain.agents.tools import Tool
59
+
60
 
61
  load_dotenv()
62
 
 
645
  return results
646
 
647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
  def process_question(question: str):
650
  tasks = planner(question)
651
  print(f"Tasks to perform: {tasks}")
 
678
  except Exception as e:
679
  print(f"❌ Error: {e}")
680
  return "Sorry, I encountered an error processing your request."
681
+
682
 
683
 
684
 
 
694
  }
695
 
696
  def build_graph(provider, model_config):
697
+ # from langgraph.prebuilt.tool_node import ToolNode
698
 
699
  def get_llm(provider: str, config: dict):
700
  if provider == "huggingface":
 
711
 
712
 
713
  llm = get_llm(provider, model_config)
714
+ return llm
715
+
716
  llm_with_tools = llm.bind_tools(tools)
717
 
 
 
 
718
 
719
 
720
  sys_msg = SystemMessage(content="You are a helpful assistant.")