asif00 committed on
Commit 607785c
1 Parent(s): 5870318

Update: Basic building block ready

Files changed (4):
  1. src/app.py +15 -0
  2. src/brain.py +185 -31
  3. src/helper.py +47 -2
  4. src/init.py +55 -0
src/app.py CHANGED
@@ -0,0 +1,15 @@
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+import gradio as gr
+from init import Initializer
+
+AUG_TOKEN = os.environ.get("AUG_TOKEN")
+RES_TOKEN = os.environ.get("RES_TOKEN")
+chroma_filename = ""
+brain = Initializer.initialize(AUG_TOKEN, RES_TOKEN, chroma_filename)
+
+# TODO:
+# Chatbot-like UI
+# Multiple PDF file handling
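
The TODO above is not implemented in this commit. A minimal sketch of what the planned chatbot-style UI could look like, assuming brain.generate_answers(query) returns a plain string; the labels and wiring here are illustrative, not the author's design:

    import gradio as gr

    def answer(query):
        # brain is the Initializer.initialize(...) result from app.py above
        result = brain.generate_answers(query)
        return result or "Sorry, no answer could be generated."

    demo = gr.Interface(
        fn=answer,
        inputs=gr.Textbox(label="Question"),
        outputs=gr.Textbox(label="Answer"),
    )
    demo.launch()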
src/brain.py CHANGED
@@ -10,48 +10,90 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
+logging.basicConfig(
+    filename="bot_errors.log",
+    level=logging.ERROR,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+)
+
 
 class GeminiEmbeddingFunction(EmbeddingFunction):
     def __call__(self, input: Documents) -> Embeddings:
         model = "models/embedding-001"
-        title = "Custom Queries"
+        title = "Custom query"
         return genai.embed_content(
             model=model, content=input, task_type="retrieval_document", title=title
         )["embedding"]
 
-
 class Brain:
     def __init__(
         self,
-        aug_model_name,
-        res_model_name,
-        aug_config,
-        res_config,
-        aug_model_key,
-        res_model_key,
-        aug_safety_settings,
-        res_safety_settings,
+        augment_model_name,
+        augment_config,
+        augment_safety_settings,
+        augment_model_api_key,
+        response_model_name,
+        generation_config,
+        response_safety_settings,
+        response_model_api_key,
         chroma_filename,
         chroma_collection_name,
     ):
-        self.aug_model_name = aug_model_name
-        self.res_model_name = res_model_name
-        self.aug_config = aug_config
-        self.res_config = res_config
-        self.aug_model_key = aug_model_key
-        self.res_model_key = res_model_key
-        self.aug_safety_settings = aug_safety_settings
-        self.res_safety_settings = res_safety_settings
-        self.chroma_filename = chroma_filename
-        self.chroma_collection_name = chroma_collection_name
-        self.embeddings = (self._initialize_embeddings_function,)
+        self.augment_model_name = augment_model_name
+        self.augment_config = augment_config
+        self.augment_safety_settings = augment_safety_settings
+        self._configure_generative_ai(response_model_api_key)
+        self._configure_augment_ai(augment_model_api_key)
+        self.response_model = self._initialize_generative_model(
+            response_model_name, generation_config, response_safety_settings
+        )
+        self.embedding_function = self._initialize_embedding_function()
         self.chroma_collection = self._load_chroma(
             chroma_filename, chroma_collection_name
         )
+        self.cross_encoder = self._initialize_cross_encoder()
 
-    def _handle_error(self, message, exception):
-        print(f"{message} : {str(exception)}")
-        logging.error((f"{message} : {str(exception)}"))
+    def _configure_generative_ai(self, response_model_api_key):
+        try:
+            genai.configure(api_key=response_model_api_key)
+        except Exception as e:
+            self._handle_error("Error configuring generative AI module", e)
+
+    def _configure_augment_ai(self, augment_model_api_key):
+        try:
+            palm.configure(api_key=augment_model_api_key)
+        except Exception as e:
+            self._handle_error("Error configuring augmentation AI module", e)
+
+    def _initialize_generative_model(
+        self, response_model_name, generation_config, response_safety_settings
+    ):
+        try:
+            return genai.GenerativeModel(
+                model_name=response_model_name,
+                generation_config=generation_config,
+                safety_settings=response_safety_settings,
+            )
+        except Exception as e:
+            self._handle_error("Error initializing generative model", e)
+
+    def _initialize_augment_model(
+        self, augment_model_name, augment_config, augment_safety_settings
+    ):
+        try:
+            return palm.GenerativeModel(
+                model_name=augment_model_name,
+                generation_config=augment_config,
+                safety_settings=augment_safety_settings,
+            )
+        except Exception as e:
+            self._handle_error("Error initializing augmentation model", e)
+
+    def _initialize_embedding_function(self):
+        try:
+            return GeminiEmbeddingFunction()
+        except Exception as e:
+            self._handle_error("Error initializing embedding function", e)
 
     def _load_chroma(self, chroma_filename, chroma_collection_name):
         try:
@@ -61,16 +103,128 @@
                 embedding_function=self.embedding_function,
             )
         except Exception as e:
-            self._handle_error("Error loading Chroma collection", e)
-
-    def _configure_res_ai(self, res_model_key):
-        try:
-            genai.configure(api_key=res_model_key)
-        except Exception as e:
-            self._handle_error("Error configuring response model", e)
-
-    def _configure_aug_ai(self, aug_model_key):
-        try:
-            palm.configure(api_key=aug_model_key)
-        except Exception as e:
-            self._handle_error("Error configuring augment model", e)
+            self._handle_error("Error loading chroma collection", e)
+
+    def _initialize_cross_encoder(self):
+        try:
+            return CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+        except Exception as e:
+            self._handle_error("Error initializing CrossEncoder model", e)
+
+    def _handle_error(self, message, exception):
+        print(f"{message}: {str(exception)}")
+        logging.error(f"{message}: {str(exception)}")
+
+    def generate_alternative_queries(self, query):
+        try:
+            prompt_template = """Your task is to break the query down into sub-questions and rephrase it in ten different ways. Output one query per line, without numbering the queries.\nQUESTION: '{}'\nANSWER:\n"""
+            prompt = prompt_template.format(query)
+            output = palm.generate_text(
+                model=self.augment_model_name,
+                prompt=prompt,
+                safety_settings=self.augment_safety_settings,
+            )
+            content = output.result.split("\n")
+            return content
+        except Exception as e:
+            self._handle_error("Error generating alternative queries", e)
+            return []
+
+    def get_sorted_documents(self, query, n_results=20):
+        try:
+            original_query = query
+            queries = [original_query] + self.generate_alternative_queries(
+                original_query
+            )
+            results = self.chroma_collection.query(
+                query_texts=queries,
+                n_results=n_results,
+                include=["documents", "embeddings"],
+            )
+            retrieved_documents = set(
+                doc for docs in results["documents"] for doc in docs
+            )
+            unique_documents = list(retrieved_documents)
+            pairs = [[original_query, doc] for doc in unique_documents]
+            scores = self.cross_encoder.predict(pairs)
+            sorted_indices = np.argsort(-scores)
+            sorted_documents = [unique_documents[i] for i in sorted_indices]
+            return sorted_documents
+        except Exception as e:
+            self._handle_error("Error getting sorted documents", e)
+            return []
+
+    def get_relevant_results(self, query, top_n=5):
+        try:
+            sorted_documents = self.get_sorted_documents(query)
+            relevant_results = sorted_documents[: min(top_n, len(sorted_documents))]
+            return relevant_results
+        except Exception as e:
+            self._handle_error("Error getting relevant results", e)
+            return []
+
+    def make_prompt(self, query, relevant_passage):
+        try:
+            base_prompt = {
+                "content": """
+            YOU are a smart and rational Question and Answer bot.
+
+            YOUR MISSION:
+            Provide accurate answers with the best possible reasoning from the context.
+            Focus on factual and reasoned responses; avoid speculation, opinions, guesses, and creative tasks.
+            Refuse exploitative tasks such as character roleplaying, coding, essays, poems, stories, articles, and fun facts.
+            Decline misuse or exploitation attempts respectfully.
+
+            YOUR STYLE:
+            Concise and complete
+            Factual and accurate
+
+            REMEMBER:
+            You are a QA bot, not an entertainer or confidant.
+            """
+            }
+
+            user_prompt = {
+                "content": f"""
+            The user query is: '{query}'\n\n
+            Here's the relevant information found in the documents:
+            {relevant_passage}
+            """
+            }
+
+            system_prompt = base_prompt["content"] + user_prompt["content"]
+            return system_prompt
+        except Exception as e:
+            self._handle_error("Error occurred while crafting prompt", e)
+            return None
+
+    def rag(self, query):
+        try:
+            if query is None:
+                return None
+            results = self.chroma_collection.query(
+                query_texts=[query],
+                n_results=10,
+                include=["documents", "embeddings"],
+            )
+            information = "\n\n".join(results["documents"][0])
+            messages = self.make_prompt(query, information)
+            content = self.response_model.generate_content(messages)
+            return content
+        except Exception as e:
+            self._handle_error("Error in rag function", e)
+            return None
+
+    def generate_answers(self, query):
+        try:
+            start_time = time.time()
+            output = self.rag(query=query)
+            print(f"\n\nExecution time: {time.time() - start_time} seconds\n")
+            if output is None:
+                return None
+            return f"{output.text}\n"
+        except Exception as e:
+            self._handle_error("Error generating answers", e)
+            return None
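
The retrieval path in get_sorted_documents expands the query into palm-generated variants, pools the Chroma results, de-duplicates them, and re-ranks the unique documents with a cross-encoder. A standalone sketch of just the re-ranking step; the query and documents are made-up examples:

    from sentence_transformers import CrossEncoder
    import numpy as np

    encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    query = "How does the bot rank retrieved passages?"
    docs = [
        "A cross-encoder scores each (query, passage) pair for relevance.",
        "Completely unrelated text about cooking rice.",
    ]
    scores = encoder.predict([[query, d] for d in docs])  # one relevance score per pair
    ranked = [docs[i] for i in np.argsort(-scores)]       # highest score first
    print(ranked[0])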
src/helper.py CHANGED
@@ -1,2 +1,47 @@
-def load_chroma(chroma_filename, chroma_collection_name, embedding_function):
-    pass
+from pypdf import PdfReader
+import chromadb
+from langchain.text_splitter import (
+    RecursiveCharacterTextSplitter,
+    SentenceTransformersTokenTextSplitter,
+)
+def _read_pdf(filename):
+    reader = PdfReader(filename)
+    pdf_texts = [p.extract_text().strip() for p in reader.pages]
+    pdf_texts = [text for text in pdf_texts if text]
+    return pdf_texts
+
+
+def _chunk_texts(texts):
+    character_splitter = RecursiveCharacterTextSplitter(
+        separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1600, chunk_overlap=200
+    )
+    character_split_texts = character_splitter.split_text("\n\n".join(texts))
+    token_splitter = SentenceTransformersTokenTextSplitter(
+        chunk_overlap=20, tokens_per_chunk=300
+    )
+    token_split_texts = []
+    for text in character_split_texts:
+        token_split_texts += token_splitter.split_text(text)
+    return token_split_texts
+
+def load_chroma(filename, collection_name, embedding_function):
+    texts = _read_pdf(filename)
+    chunks = _chunk_texts(texts)
+    chroma_client = chromadb.Client()
+    chroma_collection = chroma_client.create_collection(
+        name=collection_name, embedding_function=embedding_function
+    )
+    ids = [str(i) for i in range(len(chunks))]
+    chroma_collection.add(ids=ids, documents=chunks)
+    return chroma_collection
+
+
+def word_wrap(string, n_chars=72):
+    if len(string) < n_chars:
+        return string
+    else:
+        return (
+            string[:n_chars].rsplit(" ", 1)[0]
+            + "\n"
+            + word_wrap(string[len(string[:n_chars].rsplit(" ", 1)[0]) + 1 :], n_chars)
+        )
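
Note that load_chroma builds its collection on an in-memory chromadb.Client(), so the PDF is re-read and re-embedded on every run, and create_collection raises if the collection name already exists. A hypothetical usage sketch; example.pdf is a placeholder, and it assumes a Gemini API key has already been configured for the embedding function:

    from helper import load_chroma, word_wrap
    from brain import GeminiEmbeddingFunction

    collection = load_chroma("example.pdf", "EXAMPLE_COLLECT", GeminiEmbeddingFunction())
    hits = collection.query(query_texts=["What is this document about?"], n_results=3)
    print(word_wrap(hits["documents"][0][0]))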
src/init.py ADDED
@@ -0,0 +1,55 @@
+from brain import Brain
+
+
+class Initializer:
+    @staticmethod
+    def initialize(augment_api_key, response_api_key, chroma_filename):
+        response_model_name = "gemini-pro"
+        augment_model_name = "models/text-bison-001"
+        generation_config = {
+            "temperature": 0.9,
+            "top_p": 0.7,
+            "top_k": 1,
+            "max_output_tokens": 2048,
+        }
+        response_safety_settings = [
+            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+            {
+                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                "threshold": "BLOCK_NONE",
+            },
+            {
+                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                "threshold": "BLOCK_NONE",
+            },
+        ]
+
+        augment_config = {
+            "temperature": 0.9,
+            "top_p": 1,
+            "top_k": 80,
+            "max_output_tokens": 1024,
+        }
+        augment_safety_settings = [
+            {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 4},
+            {"category": "HARM_CATEGORY_TOXICITY", "threshold": 4},
+            {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 4},
+            {"category": "HARM_CATEGORY_SEXUAL", "threshold": 4},
+            {"category": "HARM_CATEGORY_MEDICAL", "threshold": 4},
+            {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 4},
+        ]
+        chroma_collection_name = str.upper(chroma_filename) + "_COLLECT"
+
+        return Brain(
+            augment_model_name,
+            augment_config,
+            augment_safety_settings,
+            augment_api_key,
+            response_model_name,
+            generation_config,
+            response_safety_settings,
+            response_api_key,
+            chroma_filename,
+            chroma_collection_name,
+        )
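
A minimal end-to-end smoke test, under stated assumptions: AUG_TOKEN and RES_TOKEN are set in .env, and manual.pdf is a hypothetical local PDF to index:

    import os
    from dotenv import load_dotenv
    from init import Initializer

    load_dotenv()
    brain = Initializer.initialize(
        os.environ["AUG_TOKEN"], os.environ["RES_TOKEN"], "manual.pdf"
    )
    print(brain.generate_answers("What does the introduction cover?"))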