Spaces:
Sleeping
Sleeping
| from pstuts_rag.configuration import Configuration | |
| from pstuts_rag.datastore import fill_the_db | |
| from pstuts_rag.graph import build_the_graph | |
| from pstuts_rag.state import PsTutsTeamState | |
| import requests | |
| import asyncio | |
| import json | |
| import os | |
| import getpass | |
| from typing import List, Tuple | |
| import re | |
| import chainlit as cl | |
| from dotenv import load_dotenv | |
| from langchain_core.documents import Document | |
| from langchain_core.language_models import BaseChatModel | |
| from langchain_core.runnables import Runnable | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.messages import HumanMessage, BaseMessage | |
| import langgraph.graph | |
| import pstuts_rag.datastore | |
| import pstuts_rag.rag | |
| import nest_asyncio | |
| from uuid import uuid4 | |
| import logging | |
# Quiet chatty third-party loggers; keep warnings and errors visible.
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("langchain").setLevel(logging.WARNING)

# Allow nested event loops (chainlit already runs one; asyncio.run inside
# callbacks would otherwise fail).
nest_asyncio.apply()

# Short random tag distinguishing this process in LangSmith traces.
unique_id = uuid4().hex[:8]

# Names of the two agent team members referenced throughout the graph.
VIDEOARCHIVE = "VideoArchiveSearch"
ADOBEHELP = "AdobeHelp"
def set_api_key_if_not_present(key_name: str, prompt_message: str = "") -> None:
    """
    Ensure an API key is present in the environment, prompting if needed.

    Args:
        key_name: Name of the environment variable to set.
        prompt_message: Custom prompt message for getpass (defaults to
            ``key_name`` when empty).
    """
    if not prompt_message:
        prompt_message = key_name
    # os.environ.get covers both "variable missing" and "set but empty"
    # in a single lookup (the original did two separate lookups).
    if not os.environ.get(key_name):
        os.environ[key_name] = getpass.getpass(prompt_message)
class ApplicationState:
    """
    Maintains the state of the application and its components.

    Attributes:
        embeddings: Embeddings model for vector operations
        docs: List of loaded documents
        qdrant_client: Client for Qdrant vector database
        vector_store: Vector store for document retrieval
        datastore_manager: Manager for data storage and retrieval
        rag: RAG chain instance used by the deployed app
        llm: Language model instance
        rag_chain: Retrieval-augmented generation chain
        ai_graph: Compiled AI agent graph
        ai_graph_sketch: State graph for AI agent orchestration
        tasks: List of asyncio tasks
        hasLoaded: Event to track when loading is complete
        pointsLoaded: Number of data points loaded into the database
    """

    # Annotations only -- the values (especially the mutable ones) are
    # created per instance in __init__.  The original defined them as
    # class attributes, so the list/Event objects were shared by every
    # instance and the Event was constructed at import time.
    embeddings: Embeddings
    docs: List[Document]
    llm: BaseChatModel
    rag_chain: Runnable
    ai_graph: Runnable
    tasks: List[asyncio.Task]
    hasLoaded: asyncio.Event
    pointsLoaded: int

    def __init__(self) -> None:
        """
        Initialize the application state and set up environment variables.
        """
        self.embeddings = None
        self.docs = []
        self.qdrant_client = None
        self.vector_store = None
        self.datastore_manager = None
        self.rag = None
        self.llm = None
        self.rag_chain = None
        self.ai_graph = None
        self.ai_graph_sketch = None
        self.tasks = []
        self.hasLoaded = asyncio.Event()
        self.pointsLoaded = 0

        load_dotenv()
        set_api_key_if_not_present("OPENAI_API_KEY")
        set_api_key_if_not_present("TAVILY_API_KEY")
        os.environ["LANGCHAIN_TRACING_V2"] = "true"
        os.environ["LANGCHAIN_PROJECT"] = (
            f"AIE - MBUDISIC - HF - CERT - {unique_id}"
        )
        set_api_key_if_not_present("LANGCHAIN_API_KEY")
# Module-level singletons shared by the chainlit callbacks below.
app_state = ApplicationState()
params = Configuration()

# Seed state for the agent team graph.
ai_state = PsTutsTeamState(
    messages=[],
    next="START",
    team_members=[VIDEOARCHIVE, ADOBEHELP],
)
async def initialize():
    """Populate the vector DB, then compile the agent graph into app_state."""
    await fill_the_db(app_state)
    graph, sketch = await build_the_graph(app_state)
    app_state.ai_graph = graph
    app_state.ai_graph_sketch = sketch
def enter_chain(message: str):
    """
    Entry point for the agent graph chain.

    Transforms a user message into the state format expected by the
    agent graph.

    Args:
        message: User's input message

    Returns:
        Dictionary with the message and team members information
    """
    return {
        "messages": [HumanMessage(content=message)],
        "team_members": [VIDEOARCHIVE, ADOBEHELP],
    }
async def on_chat_start():
    """
    Initializes the application when a new chat session starts.

    Sets up the language model, embeddings, and the RAG chain instance,
    then spawns a background task that populates the database and builds
    the agent graph.
    """
    llm = ChatOpenAI(model=params.tool_calling_model, temperature=0)
    # Use LangChain's built-in HuggingFaceEmbeddings wrapper
    embeddings = HuggingFaceEmbeddings(model_name=params.embedding_model)

    app_state.llm = llm
    app_state.embeddings = embeddings
    app_state.rag = pstuts_rag.rag.RAGChainInstance(
        name="deployed",
        qdrant_client=app_state.qdrant_client,
        llm=llm,
        embeddings=embeddings,
    )
    # Kick off DB fill + graph build without blocking chat startup.
    app_state.tasks.append(asyncio.create_task(initialize()))
def process_response(
    response_message: BaseMessage,
) -> Tuple[str, List[cl.Message]]:
    """
    Processes a response from the AI agents.

    Extracts the main text and video references from the response,
    and creates message elements for displaying video content or
    website screenshot previews.

    Args:
        response_message: Response message from the AI agent

    Returns:
        Tuple containing the text response and a list of messages with
        video or website-preview elements.
    """
    streamed_text = f"[_from: {response_message.name}_]\n"
    msg_references: List[cl.Message] = []

    if response_message.name == VIDEOARCHIVE:
        text, references = pstuts_rag.rag.RAGChainFactory.unpack_references(
            str(response_message.content)
        )
        streamed_text += text
        if len(references) > 0:
            references = json.loads(references)
            print(references)
            for ref in references:
                msg_references.append(
                    cl.Message(
                        content=(
                            f"Watch {ref['title']} from timestamp "
                            f"{round(ref['start'] // 60)}m:{round(ref['start'] % 60)}s"
                        ),
                        elements=[
                            cl.Video(
                                name=ref["title"],
                                url=f"{ref['source']}#t={ref['start']}",
                                display="side",
                            )
                        ],
                    )
                )
    else:
        streamed_text += str(response_message.content)
        # Find all URLs in the content
        urls = re.findall(
            r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[/\w\.-]*(?:\?[/\w\.-=&%]*)?",
            str(response_message.content),
        )
        print(urls)
        links = []
        # Request a screenshot preview for each unique URL found.
        for idx, u in enumerate(list(set(urls))):
            screenshot_api = "https://api.microlink.io"
            # Renamed from `params` to avoid shadowing the module-level
            # Configuration instance of the same name.
            query = {
                "url": u,
                "screenshot": True,
            }
            payload = requests.get(screenshot_api, query)
            # Response truthiness is True only for 2xx/3xx statuses.
            if payload:
                print(f"Successful screenshot\n{payload.json()}")
                links.append(
                    cl.Image(
                        name=f"Website {idx} Preview: {u}",
                        display="side",  # Show in the sidebar
                        url=payload.json()["data"]["screenshot"]["url"],
                    )
                )
        print(links)
        # Only attach a preview message when at least one screenshot
        # succeeded; previously an empty cl.Message was appended for
        # every non-video response, even with no URLs at all.
        if links:
            msg_references.append(
                cl.Message(
                    content="\n".join(link.url for link in links),
                    elements=links,
                )
            )

    return streamed_text, msg_references
async def main(user_cl_message: cl.Message):
    """
    Processes incoming user messages and sends responses.

    Streams the AI agent's response, processes it to extract text and
    video references, and sends the content back to the user's chat
    interface.

    Args:
        user_cl_message: User's input chainlit message
    """
    for s in app_state.ai_graph.stream(
        user_cl_message.content, {"recursion_limit": 20}
    ):
        # Skip the terminal marker and the supervisor's routing chatter.
        if "__end__" not in s and "supervisor" not in s.keys():
            for node_type, node_response in s.items():
                print(f"Processing {node_type} messages")
                for node_message in node_response["messages"]:
                    print(f"Message {node_message}")
                    msg = cl.Message(content="")
                    text, references = process_response(node_message)
                    # Strings are iterable by character; no need for the
                    # intermediate list the original built.
                    for token in text:
                        await msg.stream_token(token)
                    await msg.send()
                    for m in references:
                        await m.send()
if __name__ == "__main__":
    # `main` is an async chainlit handler taking a cl.Message; the original
    # bare `main()` call raised TypeError (missing argument) and would have
    # produced an un-awaited coroutine anyway.  This module must be served
    # by chainlit, so fail with a clear instruction instead.
    raise SystemExit("Run this app with: chainlit run <this file>")