wt002 commited on
Commit
ad00d9c
·
verified ·
1 Parent(s): 7622d0c

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +24 -20
agent.py CHANGED
@@ -1059,29 +1059,33 @@ model_config = {
1059
  "huggingfacehub_api_token": os.getenv("HF_TOKEN")
1060
  }
1061
 
1062
- # Get LLM
1063
- def get_llm(provider: str, config: dict):
1064
- if provider == "huggingface":
1065
- from langchain_huggingface import HuggingFaceEndpoint
1066
- return HuggingFaceEndpoint(
1067
- repo_id=config["repo_id"],
1068
- task=config["task"],
1069
- huggingfacehub_api_token=config["huggingfacehub_api_token"],
1070
- temperature=config["temperature"],
1071
- max_new_tokens=config["max_new_tokens"]
1072
- )
1073
- else:
1074
- raise ValueError(f"Unsupported provider: {provider}")
1075
 
1076
 
1077
- def assistant(state: dict):
1078
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
 
 
 
 
1079
 
1080
- def tools_condition(state: dict) -> str:
1081
- if "use tool" in state["messages"][-1].content.lower():
1082
- return "tools"
1083
- else:
1084
- return "END"
1085
 
1086
 
1087
  from langgraph.graph import StateGraph
 
1059
  "huggingfacehub_api_token": os.getenv("HF_TOKEN")
1060
  }
1061
 
1062
+ # Get LLM
1063
+ def get_llm(provider: str, config: dict):
1064
+ if provider == "huggingface":
1065
+ from langchain_huggingface import HuggingFaceEndpoint
1066
+ return HuggingFaceEndpoint(
1067
+ repo_id=config["repo_id"],
1068
+ task=config["task"],
1069
+ huggingfacehub_api_token=config["huggingfacehub_api_token"],
1070
+ temperature=config["temperature"],
1071
+ max_new_tokens=config["max_new_tokens"]
1072
+ )
1073
+ else:
1074
+ raise ValueError(f"Unsupported provider: {provider}")
1075
 
1076
 
1077
+ def assistant(state: dict):
1078
+ return {
1079
+ "messages": [llm_with_tools.invoke(state["messages"])]
1080
+ }
1081
+
1082
+
1083
+ def tools_condition(state: dict) -> str:
1084
+ if "use tool" in state["messages"][-1].content.lower():
1085
+ return "tools"
1086
+ else:
1087
+ return "END"
1088
 
 
 
 
 
 
1089
 
1090
 
1091
  from langgraph.graph import StateGraph