0504ankitsharma committed
Commit c47212f • Parent: 0876a19

Upload 9 files

Files changed (10):
  1. .gitattributes +1 -0
  2. Dockerfile +16 -0
  3. README.md +1 -10
  4. data/Data.pdf +3 -0
  5. get_embedding_function.py +10 -0
  6. populate_database.py +109 -0
  7. query_data.py +54 -0
  8. requirements.txt +9 -0
  9. server.py +176 -0
  10. test_rag.py +49 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/Data.pdf filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.9
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1 @@
- ---
- title: ThaparGPT
- emoji: 👍
- colorFrom: green
- colorTo: blue
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # rag-tutorial-v2
data/Data.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7ef945caf75b8219067ce06bd625f8581c60c54d58d071ef8355d9cba9294d84
+ size 1378767
get_embedding_function.py ADDED
@@ -0,0 +1,10 @@
+ from langchain_community.embeddings.ollama import OllamaEmbeddings
+ from langchain_community.embeddings.bedrock import BedrockEmbeddings
+
+ # Return the embedding model shared by indexing and querying.
+ def get_embedding_function():
+     # embeddings = BedrockEmbeddings(
+     #     credentials_profile_name="default", region_name="us-east-1"
+     # )
+     embeddings = OllamaEmbeddings(model="nomic-embed-text")
+     return embeddings
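A quick sanity check of the embedding backend (a minimal sketch; it assumes a local Ollama server with the nomic-embed-text model already pulled):

from get_embedding_function import get_embedding_function

embeddings = get_embedding_function()
vector = embeddings.embed_query("hello world")  # LangChain's standard single-query embedding call
print(len(vector))  # nomic-embed-text typically yields a 768-dimensional vector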
populate_database.py ADDED
@@ -0,0 +1,109 @@
+ import argparse
+ import os
+ import shutil
+ from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain.schema.document import Document
+ from get_embedding_function import get_embedding_function
+ from langchain_community.vectorstores import Chroma
+
+ CHROMA_PATH = "chroma"
+ DATA_PATH = "data"
+
+
+ def main():
+     # Check if the database should be cleared (using the --reset flag).
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--reset", action="store_true", help="Reset the database.")
+     args = parser.parse_args()
+     if args.reset:
+         print("✨ Clearing Database")
+         clear_database()
+
+     # Create (or update) the data store.
+     documents = load_documents()
+     chunks = split_documents(documents)
+     add_to_chroma(chunks)
+
+
+ def load_documents():
+     document_loader = PyPDFDirectoryLoader(DATA_PATH)
+     return document_loader.load()
+
+
+ def split_documents(documents: list[Document]):
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=800,
+         chunk_overlap=80,
+         length_function=len,
+         is_separator_regex=False,
+     )
+     return text_splitter.split_documents(documents)
+
+
+ def add_to_chroma(chunks: list[Document]):
+     # Load the existing database.
+     db = Chroma(
+         persist_directory=CHROMA_PATH, embedding_function=get_embedding_function()
+     )
+
+     # Calculate Page IDs.
+     chunks_with_ids = calculate_chunk_ids(chunks)
+
+     # Add or Update the documents.
+     existing_items = db.get(include=[])  # IDs are always included by default
+     existing_ids = set(existing_items["ids"])
+     print(f"Number of existing documents in DB: {len(existing_ids)}")
+
+     # Only add documents that don't exist in the DB.
+     new_chunks = []
+     for chunk in chunks_with_ids:
+         if chunk.metadata["id"] not in existing_ids:
+             new_chunks.append(chunk)
+
+     if len(new_chunks):
+         print(f"👉 Adding new documents: {len(new_chunks)}")
+         new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks]
+         db.add_documents(new_chunks, ids=new_chunk_ids)
+         db.persist()
+     else:
+         print("✅ No new documents to add")
+
+
+ def calculate_chunk_ids(chunks):
+     # This will create IDs like "data/monopoly.pdf:6:2"
+     # Page Source : Page Number : Chunk Index
+     last_page_id = None
+     current_chunk_index = 0
+
+     for chunk in chunks:
+         source = chunk.metadata.get("source")
+         page = chunk.metadata.get("page")
+         current_page_id = f"{source}:{page}"
+
+         # If the page ID is the same as the last one, increment the index.
+         if current_page_id == last_page_id:
+             current_chunk_index += 1
+         else:
+             current_chunk_index = 0
+
+         # Calculate the chunk ID.
+         chunk_id = f"{current_page_id}:{current_chunk_index}"
+         last_page_id = current_page_id
+
+         # Add it to the page meta-data.
+         chunk.metadata["id"] = chunk_id
+
+     return chunks
+
+
+ def clear_database():
+     if os.path.exists(CHROMA_PATH):
+         shutil.rmtree(CHROMA_PATH)
+
+
+ if __name__ == "__main__":
+     main()
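The ID scheme is easiest to see on concrete input. A minimal sketch (the page contents and page numbers are hypothetical, not taken from Data.pdf):

from langchain.schema.document import Document
from populate_database import calculate_chunk_ids

chunks = [
    Document(page_content="a", metadata={"source": "data/Data.pdf", "page": 6}),
    Document(page_content="b", metadata={"source": "data/Data.pdf", "page": 6}),
    Document(page_content="c", metadata={"source": "data/Data.pdf", "page": 7}),
]
print([c.metadata["id"] for c in calculate_chunk_ids(chunks)])
# ['data/Data.pdf:6:0', 'data/Data.pdf:6:1', 'data/Data.pdf:7:0']

Because the IDs are deterministic, re-running python populate_database.py only inserts chunks whose IDs are not already in Chroma.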
query_data.py ADDED
@@ -0,0 +1,54 @@
+ import argparse
+ from langchain_community.vectorstores import Chroma
+ from langchain.prompts import ChatPromptTemplate
+ from langchain_community.llms.ollama import Ollama
+
+ from get_embedding_function import get_embedding_function
+
+ CHROMA_PATH = "chroma"
+
+ PROMPT_TEMPLATE = """
+ Answer the question based only on the following context:
+
+ {context}
+
+ ---
+
+ Answer the question based on the above context: {question}
+ """
+
+
+ def main():
+     # Create CLI.
+     # parser = argparse.ArgumentParser()
+     # parser.add_argument("query_text", type=str, help="The query text.")
+     # args = parser.parse_args()
+     # query_text = args.query_text
+     # query_rag(query_text)
+     query_rag(input("Enter your query: "))
+
+
+ def query_rag(query_text: str):
+     # Prepare the DB.
+     embedding_function = get_embedding_function()
+     db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
+
+     # Search the DB.
+     results = db.similarity_search_with_score(query_text, k=5)
+
+     context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
+     prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
+     prompt = prompt_template.format(context=context_text, question=query_text)
+     # print(prompt)
+
+     model = Ollama(model="mistral")
+     response_text = model.invoke(prompt)
+
+     sources = [doc.metadata.get("id", None) for doc, _score in results]
+     formatted_response = f"Response: {response_text}\nSources: {sources}"
+     print(formatted_response)
+     return response_text
+
+
+ if __name__ == "__main__":
+     main()
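Besides the interactive prompt, query_rag can be called directly, which is what test_rag.py below relies on (a sketch; it assumes populate_database.py has already built the chroma index and that Ollama is serving the mistral model):

from query_data import query_rag

# Retrieves the top-5 chunks, stuffs them into PROMPT_TEMPLATE, and
# returns the model's answer; sources are printed as a side effect.
answer = query_rag("Give one fact from the document.")  # illustrative question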
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ pypdf
+ langchain
+ chromadb  # Vector storage
+ pytest
+ boto3
+ langchain_community
+ pyyaml
+ fastapi
+ uvicorn[standard]
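Two implicit dependencies are worth noting: the code imports langchain_text_splitters, which current langchain releases pull in transitively, and Ollama itself is a local server installed outside pip, so it is deliberately absent from this list.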
server.py ADDED
@@ -0,0 +1,176 @@
+ from fastapi import FastAPI
+ import os
+ import shutil
+ from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain.schema.document import Document
+ from get_embedding_function import get_embedding_function
+ from langchain_community.vectorstores import Chroma
+ from langchain.prompts import ChatPromptTemplate
+ from langchain_community.llms.ollama import Ollama
+
+ CHROMA_PATH = "chroma"
+ DATA_PATH = "data"
+
+ PROMPT_TEMPLATE = """
+ Answer the question based only on the following context:
+
+ {context}
+
+ ---
+
+ Answer the question based on the above context: {question}
+ """
+
+ app = FastAPI()
+
+
+ @app.get("/")
+ def greet_json():
+     return {"Hello": "World!"}
+
+
+ @app.get("/train")
+ def train(reset: bool = False):
+     # Check if the database should be cleared (pass ?reset=true).
+     if reset:
+         print("✨ Clearing Database")
+         clear_database()
+
+     # Create (or update) the data store.
+     documents = load_documents()
+     chunks = split_documents(documents)
+     add_to_chroma(chunks)
+     return {"status": "trained"}
+
+
+ def load_documents():
+     document_loader = PyPDFDirectoryLoader(DATA_PATH)
+     return document_loader.load()
+
+
+ def split_documents(documents: list[Document]):
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=800,
+         chunk_overlap=80,
+         length_function=len,
+         is_separator_regex=False,
+     )
+     return text_splitter.split_documents(documents)
+
+
+ def add_to_chroma(chunks: list[Document]):
+     # Load the existing database.
+     db = Chroma(
+         persist_directory=CHROMA_PATH, embedding_function=get_embedding_function()
+     )
+
+     # Calculate Page IDs.
+     chunks_with_ids = calculate_chunk_ids(chunks)
+
+     # Add or Update the documents.
+     existing_items = db.get(include=[])  # IDs are always included by default
+     existing_ids = set(existing_items["ids"])
+     print(f"Number of existing documents in DB: {len(existing_ids)}")
+
+     # Only add documents that don't exist in the DB.
+     new_chunks = []
+     for chunk in chunks_with_ids:
+         if chunk.metadata["id"] not in existing_ids:
+             new_chunks.append(chunk)
+
+     if len(new_chunks):
+         print(f"👉 Adding new documents: {len(new_chunks)}")
+         new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks]
+         db.add_documents(new_chunks, ids=new_chunk_ids)
+         db.persist()
+     else:
+         print("✅ No new documents to add")
+
+
+ def calculate_chunk_ids(chunks):
+     # This will create IDs like "data/monopoly.pdf:6:2"
+     # Page Source : Page Number : Chunk Index
+     last_page_id = None
+     current_chunk_index = 0
+
+     for chunk in chunks:
+         source = chunk.metadata.get("source")
+         page = chunk.metadata.get("page")
+         current_page_id = f"{source}:{page}"
+
+         # If the page ID is the same as the last one, increment the index.
+         if current_page_id == last_page_id:
+             current_chunk_index += 1
+         else:
+             current_chunk_index = 0
+
+         # Calculate the chunk ID.
+         chunk_id = f"{current_page_id}:{current_chunk_index}"
+         last_page_id = current_page_id
+
+         # Add it to the page meta-data.
+         chunk.metadata["id"] = chunk_id
+
+     return chunks
+
+
+ def clear_database():
+     if os.path.exists(CHROMA_PATH):
+         shutil.rmtree(CHROMA_PATH)
+
+
+ @app.get("/query")
+ def query(query_text: str):
+     # Prepare the DB.
+     embedding_function = get_embedding_function()
+     db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
+
+     # Search the DB.
+     results = db.similarity_search_with_score(query_text, k=5)
+
+     context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
+     prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
+     prompt = prompt_template.format(context=context_text, question=query_text)
+     # print(prompt)
+
+     model = Ollama(model="mistral")
+     response_text = model.invoke(prompt)
+
+     sources = [doc.metadata.get("id", None) for doc, _score in results]
+     formatted_response = f"Response: {response_text}\nSources: {sources}"
+     print(formatted_response)
+     return response_text
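With the server running (uvicorn server:app --host 0.0.0.0 --port 7860, or via the Docker image), both endpoints can be exercised from any HTTP client. A minimal sketch using requests (not pinned in requirements.txt, so an assumed extra; the question is illustrative):

import requests

BASE = "http://localhost:7860"

# Build or update the Chroma index from the PDFs under data/.
requests.get(f"{BASE}/train", params={"reset": "true"}, timeout=600)

# Ask a question against the indexed documents.
r = requests.get(f"{BASE}/query", params={"query_text": "Give one fact from the document."}, timeout=120)
print(r.text)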
test_rag.py ADDED
@@ -0,0 +1,49 @@
+ from query_data import query_rag
+ from langchain_community.llms.ollama import Ollama
+
+ EVAL_PROMPT = """
+ Expected Response: {expected_response}
+ Actual Response: {actual_response}
+ ---
+ (Answer with 'true' or 'false') Does the actual response match the expected response?
+ """
+
+
+ def test_monopoly_rules():
+     assert query_and_validate(
+         question="How much total money does a player start with in Monopoly? (Answer with the number only)",
+         expected_response="$1500",
+     )
+
+
+ def test_ticket_to_ride_rules():
+     assert query_and_validate(
+         question="How many points does the longest continuous train get in Ticket to Ride? (Answer with the number only)",
+         expected_response="10 points",
+     )
+
+
+ def query_and_validate(question: str, expected_response: str):
+     response_text = query_rag(question)
+     prompt = EVAL_PROMPT.format(
+         expected_response=expected_response, actual_response=response_text
+     )
+
+     model = Ollama(model="mistral")
+     evaluation_results_str = model.invoke(prompt)
+     evaluation_results_str_cleaned = evaluation_results_str.strip().lower()
+
+     print(prompt)
+
+     if "true" in evaluation_results_str_cleaned:
+         # Print response in Green if it is correct.
+         print("\033[92m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m")
+         return True
+     elif "false" in evaluation_results_str_cleaned:
+         # Print response in Red if it is incorrect.
+         print("\033[91m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m")
+         return False
+     else:
+         raise ValueError(
+             "Invalid evaluation result. Cannot determine if 'true' or 'false'."
+         )
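Note that both tests are carried over from the upstream rag-tutorial-v2 project and quiz the index about Monopoly and Ticket to Ride rulebooks; run against the Data.pdf shipped in this commit they will almost certainly fail, so the questions and expected answers need adapting before pytest test_rag.py is meaningful here.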