Cheselle committed on
Commit b085e51
1 Parent(s): ffd0213

Update app.py

Files changed (1)
  1. app.py +101 -116
app.py CHANGED
@@ -1,138 +1,123 @@
  from langchain_openai import ChatOpenAI
  from langchain.prompts import ChatPromptTemplate
  from langchain.schema import StrOutputParser
- from langchain.schema.runnable import Runnable
- from langchain.schema.runnable.config import RunnableConfig
- from dotenv import load_dotenv
- import os
  from langchain_community.document_loaders import PyMuPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_community.vectorstores import Qdrant
  import chainlit as cl
  from sentence_transformers import SentenceTransformer

  # Load environment variables
  load_dotenv()
- os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

- class SentenceTransformerEmbedding:
-     def __init__(self, model_name):
          self.model = SentenceTransformer(model_name)

-     def embed_documents(self, texts):
-         return self.model.encode(texts, convert_to_tensor=True).tolist()  # Convert to list for compatibility

-     def __call__(self, texts):
-         return self.embed_documents(texts)  # Make it callable

- @cl.on_chat_start
- async def on_chat_start():
-     model = ChatOpenAI(streaming=True)

-     # Load documents
-     ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
-     ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()

-     print("Documents loaded.")
-
-     RAG_PROMPT = """\
-     Given a provided context and question, you must answer the question based only on context.
-
-     Context: {context}
-     Question: {question}
-     """
-
-     rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
-
-     sentence_text_splitter = RecursiveCharacterTextSplitter(
          chunk_size=500,
          chunk_overlap=100,
          separators=["\n\n", "\n", ".", "!", "?"]
      )
-
-     def metadata_generator(document, name, splitter):
-         collection = splitter.split_documents(document)
-         for doc in collection:
-             doc.metadata["source"] = name
-         return collection
-
-     sentence_framework = metadata_generator(ai_framework_document, "AI Framework", sentence_text_splitter)
-     sentence_blueprint = metadata_generator(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)
-
-     sentence_combined_documents = sentence_framework + sentence_blueprint
-
-     print(f"Total documents to embed: {len(sentence_combined_documents)}")
-
-     # Limit the number of documents processed for debugging
-     max_documents = 10
-     sentence_combined_documents = sentence_combined_documents[:max_documents]
-
-     # Initialize the embedding model instance
-     embedding_model = SentenceTransformerEmbedding('Cheselle/finetuned-arctic-sentence')
-
-     # Create the Qdrant vector store using the embedding instance
-     sentence_vectorstore = Qdrant.from_documents(
-         documents=sentence_combined_documents,
-         embedding=embedding_model,
-         location=":memory:",
-         collection_name="AI Policy"
-     )
-
-     print("Vector store created.")
-
-     # Create retriever from the vector store
-     sentence_retriever = sentence_vectorstore.as_retriever()
-
-     # Check if retriever is initialized correctly
-     if sentence_retriever is None:
-         raise ValueError("Retriever is not initialized correctly.")
-
-     # Set the retriever and prompt into session for reuse
-     cl.user_session.set("runnable", model)
-     cl.user_session.set("retriever", sentence_retriever)
-     cl.user_session.set("prompt_template", rag_prompt)
-

  @cl.on_message
- async def on_message(message: cl.Message):
-     # Get the stored model, retriever, and prompt
-     model = cl.user_session.get("runnable")
-     retriever = cl.user_session.get("retriever")
-     prompt_template = cl.user_session.get("prompt_template")
-
-     # Log the message content
-     print(f"Received message: {message.content}")
-
-     # Retrieve relevant context from documents based on the user's message
-     if retriever is None:
-         print("Retriever is not available.")
-         await cl.Message(content="Sorry, the retriever is not initialized.").send()
-         return
-
-     relevant_docs = retriever.get_relevant_documents(message.content)
-     print(f"Retrieved {len(relevant_docs)} documents.")
-
-     if not relevant_docs:
-         print("No relevant documents found.")
-         await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
-         return
-
-     context = "\n\n".join([doc.page_content for doc in relevant_docs])
-
-     # Log the context to check
-     print(f"Context: {context}")
-
-     # Construct the final RAG prompt
-     final_prompt = prompt_template.format(context=context, question=message.content)
-     print(f"Final prompt: {final_prompt}")
-
-     # Initialize a streaming message
-     msg = cl.Message(content="")
-
-     # Stream the response from the model
-     async for chunk in model.astream(
-         final_prompt,
-         config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
-     ):
-         await msg.stream_token(chunk.content)
-
-     await msg.send()
 
+ import re
+ from langchain_openai import OpenAIEmbeddings
  from langchain_openai import ChatOpenAI
+ from langchain_openai.embeddings import OpenAIEmbeddings
+
  from langchain.prompts import ChatPromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.schema import StrOutputParser
+
  from langchain_community.document_loaders import PyMuPDFLoader
  from langchain_community.vectorstores import Qdrant
+
+ from langchain_core.runnables import RunnablePassthrough, RunnableParallel
+ from langchain_core.documents import Document
+
+ from operator import itemgetter
+ import os
+ from dotenv import load_dotenv
  import chainlit as cl
+ from langchain.embeddings.base import Embeddings
  from sentence_transformers import SentenceTransformer

  # Load environment variables
  load_dotenv()

+ # Custom wrapper for SentenceTransformer to work with Langchain
+ class LangchainSentenceTransformerEmbeddings(Embeddings):
+     def __init__(self, model_name: str):
          self.model = SentenceTransformer(model_name)

+     def embed_documents(self, texts: list[str]) -> list[list[float]]:
+         # Encode the documents using SentenceTransformer's encode method
+         return self.model.encode(texts)

+     def embed_query(self, text: str) -> list[float]:
+         # Encode a single query using SentenceTransformer's encode method
+         return self.model.encode([text])[0]

+ # Initialize the custom embedding model
+ embedding_model = LangchainSentenceTransformerEmbeddings("Cheselle/finetuned-arctic-sentence")

+ # Load the documents using PyMuPDFLoader
+ ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
+ ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()

+ # Metadata generator function to add metadata to documents
+ def metadata_generator(document, name):
+     fixed_text_splitter = RecursiveCharacterTextSplitter(
          chunk_size=500,
          chunk_overlap=100,
          separators=["\n\n", "\n", ".", "!", "?"]
      )
+     collection = fixed_text_splitter.split_documents(document)
+     for doc in collection:
+         doc.metadata["source"] = name
+     return collection
+
+ # Generate metadata for the loaded documents
+ recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
+ recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")
+ combined_documents = recursive_framework_document + recursive_blueprint_document
+
+ # Combine the content of the documents
+ ai_framework_text = "".join([doc.page_content for doc in ai_framework_document])
+ ai_blueprint_text = "".join([doc.page_content for doc in ai_blueprint_document])
+
+ # Pass the custom embedding model to Qdrant to create a vectorstore
+ vectorstore = Qdrant.from_documents(
+     documents=combined_documents,   # List of documents
+     embedding=embedding_model,      # Custom Langchain wrapper for SentenceTransformer
+     location=":memory:",
+     collection_name="ai_policy"
+ )
+
+ # Set up the retriever
+ retriever = vectorstore.as_retriever()
+
+ # LLM configuration
+ llm = ChatOpenAI(model="gpt-4o-mini")
+
+ # Define the RAG (Retrieval-Augmented Generation) prompt template
+ RAG_PROMPT = """\
+ You are an AI Policy Expert.
+ Given a provided context and question, you must answer the question based only on context.
+ Think through your answer carefully and step by step.
+ Context: {context}
+ Question: {question}
+ """
+
+ rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
+
+ # Define the retrieval-augmented QA chain
+ retrieval_augmented_qa_chain = (
+     {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
+     | RunnablePassthrough.assign(context=itemgetter("context"))
+     | {"response": rag_prompt | llm, "context": itemgetter("context")}
+ )
+
+ # Chainlit event handler for receiving messages
  @cl.on_message
+ async def handle_message(message):
+     try:
+         # Process the incoming question using the RAG chain
+         result = retrieval_augmented_qa_chain.invoke({"question": message.content})
+
+         # Create a new message for the response
+         response_message = cl.Message(content=result["response"].content)
+
+         # Send the response back to the user
+         await response_message.send()
+
+     except Exception as e:
+         # Handle any exception and log it or send a response back to the user
+         error_message = cl.Message(content=f"An error occurred: {str(e)}")
+         await error_message.send()
+         print(f"Error occurred: {e}")
+
+ # Run the Chainlit server
+ if __name__ == "__main__":
+     try:
+         cl.run()
+     except Exception as e:
+         print(f"Server error occurred: {e}")