Update: Basic building block ready
Browse files- src/app.py +15 -0
- src/brain.py +185 -31
- src/helper.py +47 -2
- src/init.py +55 -0
src/app.py
CHANGED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Application entry point: load credentials from the environment and build
# the question-answering brain.
import os

import gradio as gr
from dotenv import load_dotenv

from init import Initializer

# load_dotenv must be imported before it is called; it is invoked here so the
# tokens below can be supplied through a local .env file.
load_dotenv()

AUG_TOKEN = os.environ.get("AUG_TOKEN")
RES_TOKEN = os.environ.get("RES_TOKEN")
chroma_filename = ""
brain = Initializer.initialize(AUG_TOKEN, RES_TOKEN, chroma_filename)

# TODO:
# Chatbot like UI
# Multiple PDF file handling ability
|
src/brain.py
CHANGED
@@ -10,48 +10,90 @@ from dotenv import load_dotenv
|
|
10 |
|
11 |
load_dotenv()
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class GeminiEmbeddingFunction(EmbeddingFunction):
|
15 |
def __call__(self, input: Documents) -> Embeddings:
|
16 |
model = "models/embedding-001"
|
17 |
-
title = "Custom
|
18 |
return genai.embed_content(
|
19 |
model=model, content=input, task_type="retrieval_document", title=title
|
20 |
)["embedding"]
|
21 |
|
22 |
-
|
23 |
class Brain:
|
24 |
def __init__(
|
25 |
self,
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
chroma_filename,
|
35 |
chroma_collection_name,
|
36 |
):
|
37 |
-
self.
|
38 |
-
self.
|
39 |
-
self.
|
40 |
-
self.
|
41 |
-
self.
|
42 |
-
self.
|
43 |
-
|
44 |
-
|
45 |
-
self.
|
46 |
-
self.chroma_collection_name = chroma_collection_name
|
47 |
-
self.embeddings = (self._initialize_embeddings_function,)
|
48 |
self.chroma_collection = self._load_chroma(
|
49 |
chroma_filename, chroma_collection_name
|
50 |
)
|
|
|
51 |
|
52 |
-
def
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def _load_chroma(self, chroma_filename, chroma_collection_name):
|
57 |
try:
|
@@ -61,16 +103,128 @@ class Brain:
|
|
61 |
embedding_function=self.embedding_function,
|
62 |
)
|
63 |
except Exception as e:
|
64 |
-
self._handle_error("Error loading
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
def
|
67 |
try:
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
except Exception as e:
|
70 |
-
self._handle_error("Error
|
|
|
71 |
|
72 |
-
def
|
73 |
try:
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
75 |
except Exception as e:
|
76 |
-
self._handle_error("Error
|
|
|
|
10 |
|
11 |
load_dotenv()
|
12 |
|
13 |
+
logging.basicConfig(
|
14 |
+
filename="bot_errors.log",
|
15 |
+
level=logging.ERROR,
|
16 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
17 |
+
)
|
18 |
+
|
19 |
|
20 |
class GeminiEmbeddingFunction(EmbeddingFunction):
    """Embedding function that delegates to ``genai.embed_content``."""

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return the "embedding" entry of the response.

        The request always uses the "models/embedding-001" model with the
        "retrieval_document" task type and a fixed title.
        """
        response = genai.embed_content(
            model="models/embedding-001",
            content=input,
            task_type="retrieval_document",
            title="Custom query",
        )
        return response["embedding"]
|
27 |
|
|
|
28 |
class Brain:
|
29 |
def __init__(
    self,
    augment_model_name,
    augment_config,
    augment_safety_settings,
    augment_model_api_key,
    response_model_name,
    generation_config,
    response_safety_settings,
    response_model_api_key,
    chroma_filename,
    chroma_collection_name,
):
    """Configure both model clients and build the retrieval components.

    Args:
        augment_model_name: Model used to generate alternative queries.
        augment_config: Generation settings for the augmentation model
            (stored on the instance).
        augment_safety_settings: Safety settings for the augmentation model.
        augment_model_api_key: API key passed to ``palm.configure``.
        response_model_name: Model used to produce the final answer.
        generation_config: Generation settings for the response model.
        response_safety_settings: Safety settings for the response model.
        response_model_api_key: API key passed to ``genai.configure``.
        chroma_filename: Forwarded to ``_load_chroma``.
        chroma_collection_name: Forwarded to ``_load_chroma``.
    """
    # Augmentation settings are only stored; they are read later when
    # alternative queries are generated.
    self.augment_model_name = augment_model_name
    self.augment_config = augment_config
    self.augment_safety_settings = augment_safety_settings

    # API keys are registered before any model object is created.
    # NOTE(review): if `palm` and `genai` alias the same package, the second
    # configure call replaces the first key - confirm against the imports.
    self._configure_generative_ai(response_model_api_key)
    self._configure_augment_ai(augment_model_api_key)

    self.response_model = self._initialize_generative_model(
        response_model_name,
        generation_config,
        response_safety_settings,
    )

    # The embedding function must exist before the collection is loaded.
    self.embedding_function = self._initialize_embedding_function()
    self.chroma_collection = self._load_chroma(
        chroma_filename,
        chroma_collection_name,
    )
    self.cross_encoder = self._initialize_cross_encoder()
|
55 |
|
56 |
+
def _configure_generative_ai(self, response_model_api_key):
    """Register the response-model API key with ``genai``.

    Any exception is reported through ``_handle_error`` instead of being
    propagated.
    """
    try:
        genai.configure(api_key=response_model_api_key)
    except Exception as err:
        self._handle_error("Error configuring generative AI module", err)
|
61 |
+
|
62 |
+
def _configure_augment_ai(self, augment_model_api_key):
    """Register the augmentation-model API key with ``palm``.

    Any exception is reported through ``_handle_error`` instead of being
    propagated.
    """
    try:
        palm.configure(api_key=augment_model_api_key)
    except Exception as err:
        self._handle_error("Error configuring augmentation AI module", err)
|
67 |
+
|
68 |
+
def _initialize_generative_model(
    self, response_model_name, generation_config, response_safety_settings
):
    """Build the ``genai.GenerativeModel`` used to answer queries.

    Returns the model, or None (after reporting through ``_handle_error``)
    when construction raises.
    """
    model_kwargs = {
        "model_name": response_model_name,
        "generation_config": generation_config,
        "safety_settings": response_safety_settings,
    }
    try:
        return genai.GenerativeModel(**model_kwargs)
    except Exception as err:
        self._handle_error("Error initializing generative model", err)
|
79 |
+
|
80 |
+
def _initialize_augment_model(
    self, augment_model_name, augment_config, augment_safety_settings
):
    """Build the ``palm.GenerativeModel`` for query augmentation.

    Returns the model, or None (after reporting through ``_handle_error``)
    when construction raises.
    """
    model_kwargs = {
        "model_name": augment_model_name,
        "generation_config": augment_config,
        "safety_settings": augment_safety_settings,
    }
    try:
        return palm.GenerativeModel(**model_kwargs)
    except Exception as err:
        self._handle_error("Error initializing augmentation model", err)
|
91 |
+
|
92 |
+
def _initialize_embedding_function(self):
    """Create the ``GeminiEmbeddingFunction`` instance.

    Returns the instance, or None (after reporting through
    ``_handle_error``) when construction raises.
    """
    try:
        embedder = GeminiEmbeddingFunction()
    except Exception as err:
        self._handle_error("Error initializing embedding function", err)
    else:
        return embedder
|
97 |
|
98 |
def _load_chroma(self, chroma_filename, chroma_collection_name):
|
99 |
try:
|
|
|
103 |
embedding_function=self.embedding_function,
|
104 |
)
|
105 |
except Exception as e:
|
106 |
+
self._handle_error("Error loading chroma collection", e)
|
107 |
+
|
108 |
+
def _initialize_cross_encoder(self):
    """Load the "cross-encoder/ms-marco-MiniLM-L-6-v2" ``CrossEncoder``.

    Returns the encoder, or None (after reporting through
    ``_handle_error``) when loading raises.
    """
    try:
        encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    except Exception as err:
        self._handle_error("Error initializing CrossEncoder model", err)
    else:
        return encoder
|
113 |
+
|
114 |
+
def _handle_error(self, message, exception):
    """Print ``message`` with the exception text and log it at ERROR level."""
    report = f"{message}: {str(exception)}"
    print(report)
    logging.error(report)
|
117 |
+
|
118 |
+
def generate_alternative_queries(self, query):
    """Ask the augmentation model for rephrasings of ``query``.

    Args:
        query: The user's original question.

    Returns:
        A list of alternative queries, one per non-blank line of the model
        output. The list is empty when the model call fails or produces no
        text, so callers can always concatenate the result with other lists.
    """
    try:
        prompt_template = """ Your task is to break down the query in sub questions and turn it into questions in to ten different ways.Keep in mind, Output one query per line, without numbering the queries.\nQUESTION: '{}'\nANSWER:\n"""
        prompt = prompt_template.format(query)
        output = palm.generate_text(
            model=self.augment_model_name,
            prompt=prompt,
            safety_settings=self.augment_safety_settings,
        )
        # Presumably `result` can be None when the model returns no candidate
        # (e.g. a blocked prompt) - treat that as "no alternatives".
        text = output.result or ""
        # Blank lines would otherwise be sent on as empty queries.
        return [line.strip() for line in text.split("\n") if line.strip()]
    except Exception as e:
        self._handle_error("Error generating alternative queries", e)
        # Must be a list: the caller prepends the original query to it.
        return []
|
132 |
+
|
133 |
+
def get_sorted_documents(self, query, n_results=20):
    """Retrieve documents for ``query`` and its rephrasings, best match first.

    Args:
        query: The user's original question.
        n_results: Number of documents requested from the collection per
            query text.

    Returns:
        The distinct retrieved documents ordered by descending cross-encoder
        score against the original query. Empty when nothing was retrieved
        or when any step raises (the error is reported via ``_handle_error``).
    """
    try:
        alternatives = self.generate_alternative_queries(query)
        if isinstance(alternatives, str):
            # A bare string here is a fallback value, not a list of
            # rephrasings; concatenating it to a list would raise TypeError.
            alternatives = []
        queries = [query] + list(alternatives)

        # Only the documents are used below, so embeddings are not requested.
        results = self.chroma_collection.query(
            query_texts=queries,
            n_results=n_results,
            include=["documents"],
        )

        # dict.fromkeys de-duplicates while keeping first-seen order, which
        # keeps the ordering of equally scored documents deterministic.
        unique_documents = list(
            dict.fromkeys(doc for docs in results["documents"] for doc in docs)
        )
        if not unique_documents:
            return []

        pairs = [[query, doc] for doc in unique_documents]
        scores = np.asarray(self.cross_encoder.predict(pairs))
        ranking = np.argsort(-scores, kind="stable")
        return [unique_documents[i] for i in ranking]
    except Exception as e:
        self._handle_error("Error getting sorted documents", e)
        return []
|
157 |
+
|
158 |
+
def get_relevant_results(self, query, top_n=5):
    """Return the ``top_n`` best-ranked documents for ``query``.

    Args:
        query: The user's question.
        top_n: Maximum number of documents to return; values below zero are
            treated as zero.

    Returns:
        A list of at most ``top_n`` documents, best first. Empty when
        retrieval raises (the error is reported via ``_handle_error``).
    """
    try:
        sorted_documents = self.get_sorted_documents(query)
        # Slicing already clamps to the list length; only the lower bound
        # needs guarding so a negative count cannot slice from the end.
        return sorted_documents[: max(top_n, 0)]
    except Exception as e:
        self._handle_error("Error getting relevant results", e)
        # Keep the return type a list on failure as well.
        return []
|
166 |
+
|
167 |
+
def make_prompt(self, query, relevant_passage):
    """Build the full prompt sent to the response model.

    Args:
        query: The user's question; interpolated into the prompt.
        relevant_passage: Retrieved document text; interpolated after the
            query.

    Returns:
        The fixed instruction text followed by the query and passage
        section, or None when formatting raises (the error is reported via
        ``_handle_error``).
    """
    try:
        base_prompt = {
            "content": """
            YOU are a smart and rational Question and Answer bot.

            YOUR MISSION:
            Provide accurate answers with the best possible reasoning of the context.
            Focus on factual and reasoned responses; avoid speculations, opinions, guesses, and creative tasks.
            Refuse exploitation tasks such as character roleplaying, coding, essays, poems, stories, articles, and fun facts.
            Decline misuse or exploitation attempts respectfully.

            YOUR STYLE:
            Concise and complete
            Factual and accurate

            REMEMBER:
            You are a QA bot, not an entertainer or confidant.
            """
        }

        user_prompt = {
            "content": f"""
            The user query is: '{query}'\n\n
            Here's the relevant information found in the documents:
            {relevant_passage}
            """
        }

        system_prompt = base_prompt["content"] + user_prompt["content"]
        return system_prompt

    except Exception as e:
        # Report through the shared handler so the failure is also logged.
        self._handle_error("Error occurred while crafting prompt", e)
        return None
|
202 |
|
203 |
+
def rag(self, query):
    """Answer ``query`` from the documents retrieved for it.

    Args:
        query: The user's question, or None.

    Returns:
        The response model's output for the assembled prompt. None when
        ``query`` is None, when no prompt could be built, or when any step
        raises (the error is reported via ``_handle_error``).
    """
    try:
        if query is None:
            return None
        # Only the document texts are used, so embeddings are not requested.
        results = self.chroma_collection.query(
            query_texts=[query],
            n_results=10,
            include=["documents"],
        )
        information = "\n\n".join(results["documents"][0])
        messages = self.make_prompt(query, information)
        if messages is None:
            # make_prompt reports its own failure; do not send an empty
            # prompt to the model.
            return None
        return self.response_model.generate_content(messages)
    except Exception as e:
        self._handle_error("Error in rag function", e)
        return None
|
219 |
|
220 |
+
def generate_answers(self, query):
    """Run the RAG pipeline for ``query`` and return the answer text.

    Prints the elapsed time of the ``rag`` call.

    Returns:
        The model's answer text followed by a newline, or None when ``rag``
        produced nothing or an error occurred (reported via
        ``_handle_error``).
    """
    try:
        # perf_counter is monotonic, so the measured interval is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()
        output = self.rag(query=query)
        print(f"\n\nExecution time: {time.perf_counter() - start_time} seconds\n")
        if output is None:
            return None
        return f"{output.text}\n"
    except Exception as e:
        self._handle_error("Error generating answers", e)
        return None
|
src/helper.py
CHANGED
@@ -1,2 +1,47 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pypdf import PdfReader
|
2 |
+
import chromadb
|
3 |
+
from langchain.text_splitter import (
|
4 |
+
RecursiveCharacterTextSplitter,
|
5 |
+
SentenceTransformersTokenTextSplitter,
|
6 |
+
)
|
7 |
+
def _read_pdf(filename):
    """Return the stripped text of each page of ``filename``.

    Pages whose extracted text is empty after stripping are left out.
    """
    stripped_pages = [page.extract_text().strip() for page in PdfReader(filename).pages]
    return list(filter(None, stripped_pages))
|
12 |
+
|
13 |
+
|
14 |
+
def _chunk_texts(texts):
    """Split ``texts`` into token-bounded chunks in two passes.

    The texts are joined with blank lines and cut by a recursive character
    splitter (chunk size 1600, overlap 200); each resulting piece is then
    cut again by a sentence-transformers token splitter (300 tokens per
    chunk, overlap 20). Returns the flat list of final chunks.
    """
    coarse_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""],
        chunk_size=1600,
        chunk_overlap=200,
    )
    coarse_chunks = coarse_splitter.split_text("\n\n".join(texts))

    fine_splitter = SentenceTransformersTokenTextSplitter(
        chunk_overlap=20,
        tokens_per_chunk=300,
    )
    final_chunks = []
    for coarse_chunk in coarse_chunks:
        final_chunks.extend(fine_splitter.split_text(coarse_chunk))
    return final_chunks
|
26 |
+
|
27 |
+
def load_chroma(filename, collection_name, embedding_function):
    """Read a PDF, chunk it, and load the chunks into a new Chroma collection.

    The collection is created on a fresh ``chromadb.Client()`` under
    ``collection_name`` with the given ``embedding_function``; chunk ids are
    the chunk positions as strings. Returns the populated collection.
    """
    chunks = _chunk_texts(_read_pdf(filename))
    client = chromadb.Client()
    collection = client.create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )
    chunk_ids = [str(position) for position, _ in enumerate(chunks)]
    collection.add(ids=chunk_ids, documents=chunks)
    return collection
|
37 |
+
|
38 |
+
|
39 |
+
def word_wrap(string, n_chars=72):
    """Break ``string`` into newline-separated lines of at most ``n_chars``.

    Each line is cut at the last space inside the next ``n_chars``-character
    window, and that space is replaced by the newline. A window containing
    no space is cut at exactly ``n_chars`` characters, so no character is
    dropped. A string shorter than ``n_chars`` is returned unchanged.

    Args:
        string: The text to wrap.
        n_chars: Window width in characters; must be at least 1.

    Returns:
        The wrapped text.

    Raises:
        ValueError: If ``n_chars`` is less than 1.
    """
    if n_chars < 1:
        raise ValueError("n_chars must be at least 1")

    lines = []
    start = 0
    # Iterative rather than recursive so long inputs cannot exhaust the
    # recursion limit.
    while len(string) - start >= n_chars:
        window = string[start : start + n_chars]
        head, space, _ = window.rpartition(" ")
        if space:
            lines.append(head)
            start += len(head) + 1  # skip the space that became the break
        else:
            # No space to break on: hard-cut and keep every character.
            lines.append(window)
            start += n_chars
    lines.append(string[start:])
    return "\n".join(lines)
|
src/init.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from brain import Brain
|
2 |
+
|
3 |
+
|
4 |
+
class Initializer:
    """Factory holding the default model names and settings for ``Brain``."""

    @staticmethod
    def initialize(augment_api_key, response_api_key, chroma_filename):
        """Build a ``Brain`` with the default models and settings.

        Args:
            augment_api_key: API key for the augmentation model.
            response_api_key: API key for the response model.
            chroma_filename: Passed through to ``Brain``; its upper-cased
                form plus "_COLLECT" becomes the collection name.

        Returns:
            The constructed ``Brain``.
        """
        # Response model: every listed harm category uses "BLOCK_NONE".
        response_categories = (
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
        response_safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in response_categories
        ]
        generation_config = {
            "temperature": 0.9,
            "top_p": 0.7,
            "top_k": 1,
            "max_output_tokens": 2048,
        }

        # Augmentation model: every listed harm category uses threshold 4.
        augment_categories = (
            "HARM_CATEGORY_DEROGATORY",
            "HARM_CATEGORY_TOXICITY",
            "HARM_CATEGORY_VIOLENCE",
            "HARM_CATEGORY_SEXUAL",
            "HARM_CATEGORY_MEDICAL",
            "HARM_CATEGORY_DANGEROUS",
        )
        augment_safety_settings = [
            {"category": category, "threshold": 4}
            for category in augment_categories
        ]
        augment_config = {
            "temperature": 0.9,
            "top_p": 1,
            "top_k": 80,
            "max_output_tokens": 1024,
        }

        collection_name = str.upper(chroma_filename) + "_COLLECT"

        return Brain(
            augment_model_name="models/text-bison-001",
            augment_config=augment_config,
            augment_safety_settings=augment_safety_settings,
            augment_model_api_key=augment_api_key,
            response_model_name="gemini-pro",
            generation_config=generation_config,
            response_safety_settings=response_safety_settings,
            response_model_api_key=response_api_key,
            chroma_filename=chroma_filename,
            chroma_collection_name=collection_name,
        )
|