mbudisic committed on
Commit
d3e86a1
·
1 Parent(s): 93d0e0d

starting work on agentic app.py

Files changed (1)
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
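+ """Chainlit app that answers questions with a Qdrant-backed RAG chain and cites source videos."""
+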
+ import asyncio
+ import getpass
+ import json
+ import os
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List
+
+ import chainlit as cl
+ from dotenv import load_dotenv
+ from langchain_core.documents import Document
+ from langchain_core.language_models import BaseChatModel
+ from langchain_core.runnables import Runnable
+ from langchain_openai import ChatOpenAI
+ from langchain_openai.embeddings import OpenAIEmbeddings
+ from langchain_qdrant import QdrantVectorStore
+ from qdrant_client import QdrantClient
+
+ import pstuts_rag.datastore
+ import pstuts_rag.rag
+ from pstuts_rag.loader import load_json_files
+
+
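+ # Tunable application settings: source JSON files, model names, and retrieval depth.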
+ @dataclass
+ class ApplicationParameters:
+     filename: List[str] = field(
+         default_factory=lambda: [f"data/{f}.json" for f in ["dev"]]
+     )
+     embedding_model: str = "text-embedding-3-small"
+     n_context_docs: int = 2
+     llm_model: str = "gpt-4.1-mini"
+
+
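+ # Prompt for an API key interactively only when it is missing from the environment.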
+ def set_api_key_if_not_present(key_name, prompt_message=""):
+     if len(prompt_message) == 0:
+         prompt_message = key_name
+     if key_name not in os.environ or not os.environ[key_name]:
+         os.environ[key_name] = getpass.getpass(prompt_message)
+
+
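+ # Mutable holder for resources built at chat start: clients, vector store, and RAG chain.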
+ class ApplicationState:
+     embeddings: OpenAIEmbeddings = None
+     docs: List[Document] = []
+     qdrant_client: QdrantClient = None
+     vector_store: QdrantVectorStore = None
+     datastore_manager: pstuts_rag.datastore.DatastoreManager
+     rag_factory: pstuts_rag.rag.RAGChainFactory
+     llm: BaseChatModel
+     rag_chain: Runnable
+
+     hasLoaded: asyncio.Event = asyncio.Event()
+     pointsLoaded: int = 0
+
+     def __init__(self) -> None:
+         load_dotenv()
+         set_api_key_if_not_present("OPENAI_API_KEY")
+
+
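+ # Module-level singletons shared by the Chainlit callbacks below.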
+ state = ApplicationState()
+ params = ApplicationParameters()
+
+
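+ # Populate the Qdrant datastore from the configured JSON files, but only if it is still empty.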
+ async def fill_the_db():
+     if state.datastore_manager.count_docs() == 0:
+         data: List[Dict[str, Any]] = await load_json_files(params.filename)
+         state.pointsLoaded = await state.datastore_manager.populate_database(
+             raw_docs=data
+         )
+         await cl.Message(
+             content=f"✅ The database has been loaded with {state.pointsLoaded} elements!"
+         ).send()
+
+
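+ # Build the RAG chain: wrap the datastore retriever in the factory and attach the chat model.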
+ async def build_the_chain():
+     state.rag_factory = pstuts_rag.rag.RAGChainFactory(
+         retriever=state.datastore_manager.get_retriever()
+     )
+     state.llm = ChatOpenAI(model=params.llm_model, temperature=0)
+     state.rag_chain = state.rag_factory.get_rag_chain(state.llm)
+
+
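+ # On chat start: create an in-memory Qdrant collection, fill it, and assemble the chain.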
+ @cl.on_chat_start
+ async def on_chat_start():
+     state.qdrant_client = QdrantClient(":memory:")
+
+     state.datastore_manager = pstuts_rag.datastore.DatastoreManager(
+         qdrant_client=state.qdrant_client, name="local_test"
+     )
+     # Already inside Chainlit's event loop, so await directly instead of asyncio.run().
+     await fill_the_db()
+     await build_the_chain()
+
+
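+ # Answer each user message: stream the RAG answer, then post the referenced videos.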
+ @cl.on_message
+ async def main(message: cl.Message):
+     # Send a response back to the user
+     msg = cl.Message(content="")
+     response = await state.rag_chain.ainvoke({"question": message.content})
+
+     text, references = pstuts_rag.rag.RAGChainFactory.unpack_references(
+         response.content
+     )
+     if isinstance(text, str):
+         for token in text:
+             await msg.stream_token(token)
+
+     await msg.send()
+
+     references = json.loads(references)
+     print(references)
+
+     msg_references = [
+         (
+             f"Watch {ref['title']} from timestamp "
+             f"{round(ref['start'] // 60)}m:{round(ref['start'] % 60)}s",
+             cl.Video(
+                 name=ref["title"],
+                 url=f"{ref['source']}#t={ref['start']}",
+                 display="side",
+             ),
+         )
+         for ref in references
+     ]
+     await cl.Message(content="Related videos").send()
+     for e in msg_references:
+         await cl.Message(content=e[0], elements=[e[1]]).send()
+
+
+ if __name__ == "__main__":
+     # The app is launched by the Chainlit CLI, not by direct execution of this module.
+     print("Run with: chainlit run app.py")