Spaces:
Sleeping
Sleeping
| import asyncio | |
| from typing import List | |
| import chainlit as cl | |
| import json | |
| import os | |
| from dotenv import load_dotenv | |
| from langchain_experimental.text_splitter import SemanticChunker | |
| from langchain_openai import ChatOpenAI | |
| from langchain_openai.embeddings import OpenAIEmbeddings | |
| from langchain_core.documents import Document | |
| from langchain_qdrant import QdrantVectorStore | |
| from pstuts_rag.rag import RAGChainFactory, RetrieverFactory | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import Distance, VectorParams | |
| from dataclasses import dataclass | |
| import pstuts_rag.rag | |
class ApplicationParameters:
    """Static configuration values for the application.

    All settings are plain class attributes, so they can be read either from
    the class itself or from an instance.
    """

    # --- OpenAI model identifiers -------------------------------------
    # Chat model name handed to ``ChatOpenAI`` when the chat session starts.
    llm_model: str = "gpt-4.1-mini"
    # Embedding model name.
    # NOTE(review): not referenced by the code visible in this file -
    # presumably consumed inside pstuts_rag; confirm.
    embedding_model: str = "text-embedding-3-small"

    # --- Data and retrieval ------------------------------------------
    # JSON file loaded into the retriever when it holds no documents yet.
    filename: str = "data/test.json"
    # NOTE(review): looks like the number of context documents to retrieve;
    # not referenced by the code visible in this file - confirm.
    n_context_docs: int = 2
def set_api_key_if_not_present(key_name, prompt_message=""):
    """Ensure the environment variable *key_name* holds a non-empty value.

    If the variable is missing or empty, the user is prompted for it without
    echoing the input, and the answer is stored in ``os.environ``.

    Args:
        key_name: Name of the environment variable to check and populate.
        prompt_message: Prompt shown to the user. When empty (or ``None``),
            *key_name* itself is used as the prompt.
    """
    # Imported here because this file's module-level imports do not include
    # getpass; without it the prompt path fails with NameError.
    import getpass

    # A variable that is set to a non-empty value needs no prompting.
    if os.environ.get(key_name):
        return

    os.environ[key_name] = getpass.getpass(prompt_message or key_name)
class ApplicationState:
    """Mutable holder for the objects the chat handlers share.

    Constructing an instance loads variables from a ``.env`` file and makes
    sure ``OPENAI_API_KEY`` is set, prompting for it if necessary.
    """

    embeddings: OpenAIEmbeddings = None
    # Declared without a class-level default: a list literal here would be a
    # single object shared by every instance. It is created in __init__.
    docs: List[Document]
    qdrantclient: QdrantClient = None
    vectorstore: QdrantVectorStore = None
    retriever_factory: pstuts_rag.rag.RetrieverFactory
    rag_factory: pstuts_rag.rag.RAGChainFactory

    # Attributes assigned by the chat-start handler; declared here so the
    # class lists every attribute the handlers rely on.
    client: QdrantClient = None
    llm: ChatOpenAI = None
    rag_chain = None  # built by rag_factory.get_rag_chain(); type not visible here

    def __init__(self) -> None:
        """Load environment configuration and create per-instance containers."""
        load_dotenv()
        set_api_key_if_not_present("OPENAI_API_KEY")
        # Fresh list per instance so documents are never shared across states.
        self.docs = []
# Module-level singletons shared by the chat handlers below.
# `params` carries the static configuration; creating it has no side effects.
params = ApplicationParameters()
# `state` carries the runtime objects; creating it loads .env and may prompt
# for the OpenAI API key.
state = ApplicationState()
@cl.on_chat_start
async def on_chat_start():
    """Build the retrieval pipeline when a chat session starts.

    Creates an in-memory Qdrant client and a retriever factory on top of it,
    loads the JSON documents from ``params.filename`` if the retriever holds
    no documents yet, and finally builds the RAG chain that the message
    handler invokes. All objects are stored on the module-level ``state``.
    """
    state.client = QdrantClient(":memory:")

    state.retriever_factory = pstuts_rag.rag.RetrieverFactory(
        qdrant_client=state.client, name="local_test"
    )

    if state.retriever_factory.count_docs() == 0:
        # Context manager so the file handle is closed right after parsing.
        with open(params.filename, "rb") as source:
            # presumably a list of dict records; verify against aadd_docs
            data = json.load(source)
        # This coroutine already runs inside an event loop, so the ingest
        # coroutine must be awaited; asyncio.run() would raise RuntimeError
        # here because a loop is already running.
        await state.retriever_factory.aadd_docs(raw_docs=data)

    state.rag_factory = pstuts_rag.rag.RAGChainFactory(
        retriever=state.retriever_factory.get_retriever()
    )

    # temperature=0 is passed through to the chat model unchanged.
    state.llm = ChatOpenAI(model=params.llm_model, temperature=0)
    state.rag_chain = state.rag_factory.get_rag_chain(state.llm)
@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message using the RAG chain.

    Args:
        message: The message received from the user; its ``content`` is
            passed to the chain as the query.
    """
    # The chain is built by the chat-start handler and stored on `state`.
    answer = await state.rag_chain.ainvoke(message.content)

    # Send a response back to the user
    await cl.Message(content=answer.content).send()
if __name__ == "__main__":
    # `main` is an async message handler that requires a message argument, so
    # calling it directly with no arguments fails. When this file is executed
    # as a script, start the Chainlit server on it instead (equivalent to
    # `chainlit run <this file>`).
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)