Spaces:
Sleeping
Sleeping
Tidy: includes
Browse files- app.py +8 -7
- pstuts_rag/pstuts_rag/agents.py +87 -0
- pstuts_rag/pstuts_rag/rag.py +6 -15
app.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
| 1 |
import asyncio
|
| 2 |
-
from typing import List, Dict, Any
|
| 3 |
-
import chainlit as cl
|
| 4 |
import json
|
| 5 |
import os
|
|
|
|
|
|
|
| 6 |
|
|
|
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
-
from langchain_openai import ChatOpenAI
|
| 9 |
-
from langchain_openai.embeddings import OpenAIEmbeddings
|
| 10 |
from langchain_core.documents import Document
|
| 11 |
from langchain_core.language_models import BaseChatModel
|
| 12 |
from langchain_core.runnables import Runnable
|
|
|
|
|
|
|
| 13 |
from langchain_qdrant import QdrantVectorStore
|
| 14 |
-
from pstuts_rag.loader import load_json_files, load_single_json
|
| 15 |
from qdrant_client import QdrantClient
|
| 16 |
-
from dataclasses import dataclass
|
| 17 |
|
| 18 |
-
import pstuts_rag.
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
@dataclass
|
|
|
|
| 1 |
import asyncio
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
|
| 7 |
+
import chainlit as cl
|
| 8 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
| 9 |
from langchain_core.documents import Document
|
| 10 |
from langchain_core.language_models import BaseChatModel
|
| 11 |
from langchain_core.runnables import Runnable
|
| 12 |
+
from langchain_openai import ChatOpenAI
|
| 13 |
+
from langchain_openai.embeddings import OpenAIEmbeddings
|
| 14 |
from langchain_qdrant import QdrantVectorStore
|
|
|
|
| 15 |
from qdrant_client import QdrantClient
|
|
|
|
| 16 |
|
| 17 |
+
import pstuts_rag.datastore
|
| 18 |
+
import pstuts_rag.rag
|
| 19 |
+
from pstuts_rag.loader import load_json_files
|
| 20 |
|
| 21 |
|
| 22 |
@dataclass
|
pstuts_rag/pstuts_rag/agents.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, List, Optional, TypedDict, Union
|
| 2 |
+
|
| 3 |
+
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
| 4 |
+
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
| 5 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 6 |
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
| 7 |
+
from langchain_core.runnables import Runnable
|
| 8 |
+
from langchain_core.tools import BaseTool
|
| 9 |
+
from langchain_openai import ChatOpenAI
|
| 10 |
+
from langchain_core.language_models import BaseChatModel
|
| 11 |
+
|
| 12 |
+
from langgraph.graph import END, StateGraph
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and wrap its output as a named message.

    Args:
        state: Passed unchanged to ``agent.invoke``.
        agent: Object exposing ``invoke(state)`` whose result can be indexed
            with ``"output"``.
        name: Stored as the ``name`` of the produced message.

    Returns:
        A dict with a single ``"messages"`` key holding a one-element list
        that contains a ``HumanMessage`` built from the agent's output.
    """
    output_text = agent.invoke(state)["output"]
    reply = HumanMessage(content=output_text, name=name)
    return {"messages": [reply]}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def create_agent(
    llm: BaseChatModel,
    tools: list,
    system_prompt: str,
):
    """Build an OpenAI-functions agent executor for one team member.

    The caller's ``system_prompt`` is extended with shared team instructions
    before the chat prompt is assembled. The appended text contains a
    ``{team_members}`` template placeholder that is not filled in here.

    Args:
        llm: Chat model handed to ``create_openai_functions_agent``.
        tools: Tools given both to the agent and to the executor.
        system_prompt: Role-specific instructions placed first in the system
            message.

    Returns:
        An ``AgentExecutor`` wrapping the created agent and the same tools.
    """
    # NOTE(review): any leading whitespace inside this literal could not be
    # recovered from the rendered diff -- confirm against the repository.
    team_instructions = """
Work autonomously according to your specialty, using the tools available to you.
Do not ask for clarification.
Your other team members (and other teams) will collaborate with you with their own specialties.
Your first choice should be to use your vector store RAG as the primary source of information.
If that does not provide enough context, then use the ArXiv search engine as the first attempt,
and if that does not provide enough context, then use the Tavily search engine as the second attempt.

You are chosen for a reason! You are one of the following team members: {team_members}.
"""
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt + team_instructions),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    return AgentExecutor(
        agent=create_openai_functions_agent(llm, tools, chat_prompt),
        tools=tools,
    )
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> Runnable:
    """Build an LLM-based router that picks the next team member or FINISH.

    Args:
        llm: Chat model; it is forced to call the ``route`` function.
        system_prompt: Supervisor instructions used as the first system
            message. ``{options}`` and ``{team_members}`` placeholders, if
            present, are pre-filled via ``partial``.
        members: Names of the team members that can be routed to. Any
            iterable of strings is accepted.

    Returns:
        A runnable pipeline (prompt -> function-calling model -> JSON
        function-arguments parser). The ``route`` schema requires a ``next``
        field restricted to ``"FINISH"`` or one of ``members``.
    """
    # Materialize once so tuples and one-shot iterables work for both the
    # option list and the joined team-member string below.
    member_names = list(members)
    options = ["FINISH", *member_names]

    route_function = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                # Adjacent literals are concatenated verbatim, so each
                # continuation starts with an explicit space.
                "Given the conversation above, who should act next? Or should we FINISH?"
                " If the last answer was 'I don't know', do not FINISH."
                " Select one of: {options}",
            ),
        ]
    ).partial(options=str(options), team_members=", ".join(member_names))

    return (
        prompt
        | llm.bind_functions(functions=[route_function], function_call="route")
        | JsonOutputFunctionsParser()
    )
|
pstuts_rag/pstuts_rag/rag.py
CHANGED
|
@@ -6,32 +6,25 @@ This module provides the core RAG functionality, including:
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import json
|
| 9 |
-
from multiprocessing import Value
|
| 10 |
import re
|
| 11 |
-
|
| 12 |
from operator import itemgetter
|
| 13 |
-
from typing import Dict, List,
|
| 14 |
|
|
|
|
| 15 |
from langchain_core.documents import Document
|
|
|
|
|
|
|
| 16 |
from langchain_core.runnables import (
|
| 17 |
Runnable,
|
| 18 |
RunnableLambda,
|
| 19 |
RunnablePassthrough,
|
| 20 |
)
|
| 21 |
-
|
| 22 |
-
from langchain_openai.embeddings import OpenAIEmbeddings
|
| 23 |
-
from langchain_qdrant import QdrantVectorStore
|
| 24 |
-
from qdrant_client import QdrantClient
|
| 25 |
-
|
| 26 |
-
from langchain.prompts import ChatPromptTemplate
|
| 27 |
from langchain_core.vectorstores import VectorStoreRetriever
|
| 28 |
from langchain_openai import ChatOpenAI
|
| 29 |
|
| 30 |
from .prompt_templates import RAG_PROMPT_TEMPLATES
|
| 31 |
|
| 32 |
-
from langchain_core.language_models.base import BaseLanguageModel
|
| 33 |
-
from langchain_core.messages import AIMessage
|
| 34 |
-
|
| 35 |
|
| 36 |
class RAGChainFactory:
|
| 37 |
"""Factory class for creating RAG (Retrieval Augmented Generation) chains.
|
|
@@ -150,9 +143,7 @@ class RAGChainFactory:
|
|
| 150 |
|
| 151 |
def get_rag_chain(
|
| 152 |
self,
|
| 153 |
-
llm:
|
| 154 |
-
model="gpt-4.1-mini", temperature=0
|
| 155 |
-
),
|
| 156 |
) -> Runnable:
|
| 157 |
"""Build and return the complete RAG chain.
|
| 158 |
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import json
|
|
|
|
| 9 |
import re
|
| 10 |
+
|
| 11 |
from operator import itemgetter
|
| 12 |
+
from typing import Any, Dict, List, Tuple
|
| 13 |
|
| 14 |
+
from langchain.prompts import ChatPromptTemplate
|
| 15 |
from langchain_core.documents import Document
|
| 16 |
+
from langchain_core.language_models import BaseChatModel
|
| 17 |
+
from langchain_core.messages import AIMessage
|
| 18 |
from langchain_core.runnables import (
|
| 19 |
Runnable,
|
| 20 |
RunnableLambda,
|
| 21 |
RunnablePassthrough,
|
| 22 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from langchain_core.vectorstores import VectorStoreRetriever
|
| 24 |
from langchain_openai import ChatOpenAI
|
| 25 |
|
| 26 |
from .prompt_templates import RAG_PROMPT_TEMPLATES
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
class RAGChainFactory:
|
| 30 |
"""Factory class for creating RAG (Retrieval Augmented Generation) chains.
|
|
|
|
| 143 |
|
| 144 |
def get_rag_chain(
|
| 145 |
self,
|
| 146 |
+
llm: BaseChatModel = ChatOpenAI(model="gpt-4.1-mini", temperature=0),
|
|
|
|
|
|
|
| 147 |
) -> Runnable:
|
| 148 |
"""Build and return the complete RAG chain.
|
| 149 |
|