Shreyas094 committed on
Commit
041d8cf
·
verified ·
1 Parent(s): 63fcaee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -188
app.py CHANGED
@@ -29,132 +29,66 @@ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
29
  # Download necessary NLTK data
30
  nltk.download('punkt')
31
  nltk.download('averaged_perceptron_tagger')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- class Agent1:
34
- def __init__(self):
35
- self.question_words = set(["what", "when", "where", "who", "whom", "which", "whose", "why", "how"])
36
- self.conjunctions = set(["and", "or"])
37
- self.pronouns = set(["it", "its", "they", "their", "them", "he", "his", "him", "she", "her", "hers"])
38
- self.context = {}
39
-
40
- def is_question(self, text: str) -> bool:
41
- words = word_tokenize(text.lower())
42
- return (words[0] in self.question_words or
43
- text.strip().endswith('?') or
44
- any(word in self.question_words for word in words))
45
-
46
- def find_subject(self, sentence):
47
- tokens = nltk.pos_tag(word_tokenize(sentence))
48
- subject = None
49
- for word, tag in tokens:
50
- if tag.startswith('NN'):
51
- subject = word
52
- break
53
- if tag == 'IN': # Stop at preposition
54
- break
55
- return subject
56
-
57
- def replace_pronoun(self, questions: List[str]) -> List[str]:
58
- if len(questions) < 2:
59
- return questions
60
-
61
- subject = self.find_subject(questions[0])
62
 
63
- if not subject:
64
- return questions
65
-
66
- for i in range(1, len(questions)):
67
- words = word_tokenize(questions[i])
68
- for j, word in enumerate(words):
69
- if word.lower() in self.pronouns:
70
- words[j] = subject
71
- questions[i] = ' '.join(words)
72
-
73
- return questions
74
-
75
- def rephrase_and_split(self, user_input: str) -> List[str]:
76
- words = word_tokenize(user_input)
77
- questions = []
78
- current_question = []
79
-
80
- for word in words:
81
- if word.lower() in self.conjunctions and current_question:
82
- if self.is_question(' '.join(current_question)):
83
- questions.append(' '.join(current_question))
84
- current_question = []
85
- else:
86
- current_question.append(word)
87
-
88
- if current_question:
89
- if self.is_question(' '.join(current_question)):
90
- questions.append(' '.join(current_question))
91
-
92
- if not questions:
93
- return [user_input]
94
-
95
- questions = self.replace_pronoun(questions)
96
-
97
- return questions
98
-
99
- def update_context(self, query: str):
100
- tokens = nltk.pos_tag(word_tokenize(query))
101
- noun_phrases = []
102
- current_phrase = []
103
-
104
- for word, tag in tokens:
105
- if tag.startswith('NN') or tag.startswith('JJ'):
106
- current_phrase.append(word)
107
- else:
108
- if current_phrase:
109
- noun_phrases.append(' '.join(current_phrase))
110
- current_phrase = []
111
-
112
- if current_phrase:
113
- noun_phrases.append(' '.join(current_phrase))
114
-
115
- if noun_phrases:
116
- self.context['main_topic'] = noun_phrases[0]
117
- self.context['related_topics'] = noun_phrases[1:]
118
- self.context['last_query'] = query
119
-
120
- def apply_context(self, query: str) -> str:
121
- words = word_tokenize(query.lower())
122
-
123
- if (len(words) <= 5 or
124
- any(word in self.pronouns for word in words) or
125
- (self.context.get('main_topic') and self.context['main_topic'].lower() not in query.lower())):
126
-
127
- new_query_parts = []
128
- main_topic_added = False
129
-
130
- for word in words:
131
- if word in self.pronouns and self.context.get('main_topic'):
132
- new_query_parts.append(self.context['main_topic'])
133
- main_topic_added = True
134
- else:
135
- new_query_parts.append(word)
136
-
137
- if not main_topic_added and self.context.get('main_topic'):
138
- new_query_parts.append(f"in the context of {self.context['main_topic']}")
139
-
140
- query = ' '.join(new_query_parts)
141
-
142
- if self.context.get('last_query'):
143
- query = f"{self.context['last_query']} and now {query}"
144
-
145
- return query
146
-
147
- def process(self, user_input: str) -> tuple[List[str], Dict[str, List[Dict[str, str]]]]:
148
- self.update_context(user_input)
149
- contextualized_input = self.apply_context(user_input)
150
- queries = self.rephrase_and_split(contextualized_input)
151
- print("Identified queries:", queries)
152
 
153
- results = {}
154
- for query in queries:
155
- results[query] = google_search(query)
156
 
157
- return queries, results
158
 
159
  def load_document(file: NamedTemporaryFile) -> List[Document]:
160
  """Loads and splits the document into pages."""
@@ -310,13 +244,10 @@ def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_
310
 
311
  return all_results
312
 
313
- def ask_question(question, temperature, top_p, repetition_penalty, web_search, agent1=None):
314
  if not question:
315
  return "Please enter a question."
316
 
317
- if agent1 is None:
318
- agent1 = Agent1()
319
-
320
  model = get_model(temperature, top_p, repetition_penalty)
321
  embed = get_embeddings()
322
 
@@ -328,70 +259,75 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, a
328
  max_attempts = 3
329
  context_reduction_factor = 0.7
330
 
331
- agent1.update_context(question)
332
- contextualized_question = agent1.apply_context(question)
333
 
334
  if web_search:
335
- queries, search_results = agent1.process(contextualized_question)
336
  all_answers = []
337
 
338
- for query in queries:
339
- for attempt in range(max_attempts):
340
- try:
341
- web_docs = [Document(page_content=result["text"], metadata={"source": result["link"], "query": query}) for result in search_results[query] if result["text"]]
342
-
343
- if database is None:
344
- database = FAISS.from_documents(web_docs, embed)
345
- else:
346
- database.add_documents(web_docs)
347
-
348
- database.save_local("faiss_database")
349
-
350
- context_str = "\n".join([f"Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
351
-
352
- prompt_template = """
353
- Answer the question based on the following web search results:
354
- Web Search Results:
355
- {context}
356
- Original Question: {question}
357
- If the web search results don't contain relevant information, state that the information is not available in the search results.
358
- Provide a summarized and direct answer to the original question without mentioning the web search or these instructions.
359
- Do not include any source information in your answer.
360
- """
361
-
362
- prompt_val = ChatPromptTemplate.from_template(prompt_template)
363
- formatted_prompt = prompt_val.format(context=context_str, question=query)
364
-
365
- full_response = generate_chunked_response(model, formatted_prompt)
366
-
367
- answer_patterns = [
368
- r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
369
- r"Provide a concise and direct answer to the question:",
370
- r"Answer:",
371
- r"Provide a summarized and direct answer to the original question without mentioning the web search or these instructions:",
372
- r"Do not include any source information in your answer."
373
- ]
374
-
375
- for pattern in answer_patterns:
376
- match = re.split(pattern, full_response, flags=re.IGNORECASE)
377
- if len(match) > 1:
378
- answer = match[-1].strip()
379
- break
380
- else:
381
- answer = full_response.strip()
382
-
383
- all_answers.append(answer)
384
- break
385
-
386
- except Exception as e:
387
- print(f"Error in ask_question for query '{query}' (attempt {attempt + 1}): {e}")
388
- if "Input validation error" in str(e) and attempt < max_attempts - 1:
389
- print(f"Reducing context length for next attempt")
390
- elif attempt == max_attempts - 1:
391
- all_answers.append(f"I apologize, but I'm having trouble processing the query '{query}' due to its length or complexity.")
 
 
 
 
 
 
392
 
393
  answer = "\n\n".join(all_answers)
394
- sources = set(doc.metadata['source'] for docs in search_results.values() for doc in [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in docs if result["text"]])
395
  sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
396
  answer += sources_section
397
 
@@ -453,9 +389,10 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, a
453
 
454
  return "An unexpected error occurred. Please try again later."
455
 
 
456
  # Gradio interface
457
  with gr.Blocks() as demo:
458
- gr.Markdown("# Chat with your PDF documents and Web Search")
459
 
460
  with gr.Row():
461
  file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
@@ -467,7 +404,7 @@ with gr.Blocks() as demo:
467
  with gr.Row():
468
  with gr.Column(scale=2):
469
  chatbot = gr.Chatbot(label="Conversation")
470
- question_input = gr.Textbox(label="Perplexity AI lite, enable web search to retrieve any web search results. Feel free to provide any feedbacks.")
471
  submit_button = gr.Button("Submit")
472
  with gr.Column(scale=1):
473
  temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
@@ -475,10 +412,10 @@ with gr.Blocks() as demo:
475
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
476
  web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
477
 
478
- agent1 = Agent1()
479
 
480
  def chat(question, history, temperature, top_p, repetition_penalty, web_search):
481
- answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, agent1)
482
  history.append((question, answer))
483
  return "", history
484
 
 
29
  # Download necessary NLTK data
30
  nltk.download('punkt')
31
  nltk.download('averaged_perceptron_tagger')
32
+ class ContextDrivenChatbot:
33
+ def __init__(self, history_size=5):
34
+ self.history = []
35
+ self.history_size = history_size
36
+ self.vectorizer = TfidfVectorizer()
37
+ nltk.download('punkt', quiet=True)
38
+ nltk.download('averaged_perceptron_tagger', quiet=True)
39
+
40
+ def add_to_history(self, text):
41
+ self.history.append(text)
42
+ if len(self.history) > self.history_size:
43
+ self.history.pop(0)
44
+
45
+ def get_context(self):
46
+ return " ".join(self.history)
47
+
48
+ def is_follow_up_question(self, question):
49
+ tokens = word_tokenize(question.lower())
50
+ follow_up_indicators = set(['it', 'this', 'that', 'these', 'those', 'he', 'she', 'they', 'them'])
51
+ return any(token in follow_up_indicators for token in tokens)
52
+
53
+ def extract_topics(self, text):
54
+ tokens = nltk.pos_tag(word_tokenize(text))
55
+ return [word for word, pos in tokens if pos.startswith('NN')]
56
+
57
+ def get_most_relevant_context(self, question):
58
+ if not self.history:
59
+ return question
60
+
61
+ # Create a combined context from history
62
+ combined_context = self.get_context()
63
+
64
+ # Vectorize the context and the question
65
+ vectors = self.vectorizer.fit_transform([combined_context, question])
66
+
67
+ # Calculate similarity
68
+ similarity = cosine_similarity(vectors[0], vectors[1])[0][0]
69
+
70
+ # If similarity is low, it might be a new topic
71
+ if similarity < 0.3: # This threshold can be adjusted
72
+ return question
73
+
74
+ # Otherwise, prepend the context
75
+ return f"{combined_context} {question}"
76
 
77
+ def process_question(self, question):
78
+ contextualized_question = self.get_most_relevant_context(question)
79
+
80
+ # Extract topics from the question
81
+ topics = self.extract_topics(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Check if it's a follow-up question
84
+ if self.is_follow_up_question(question):
85
+ # If it's a follow-up, make sure to include previous context
86
+ contextualized_question = f"{self.get_context()} {question}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ # Add the new question to history
89
+ self.add_to_history(question)
 
90
 
91
+ return contextualized_question, topics
92
 
93
  def load_document(file: NamedTemporaryFile) -> List[Document]:
94
  """Loads and splits the document into pages."""
 
244
 
245
  return all_results
246
 
247
+ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot):
248
  if not question:
249
  return "Please enter a question."
250
 
 
 
 
251
  model = get_model(temperature, top_p, repetition_penalty)
252
  embed = get_embeddings()
253
 
 
259
  max_attempts = 3
260
  context_reduction_factor = 0.7
261
 
262
+ contextualized_question, topics = chatbot.process_question(question)
 
263
 
264
  if web_search:
265
+ search_results = google_search(contextualized_question)
266
  all_answers = []
267
 
268
+ for attempt in range(max_attempts):
269
+ try:
270
+ web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
271
+
272
+ if database is None:
273
+ database = FAISS.from_documents(web_docs, embed)
274
+ else:
275
+ database.add_documents(web_docs)
276
+
277
+ database.save_local("faiss_database")
278
+
279
+ context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
280
+
281
+ prompt_template = """
282
+ Answer the question based on the following web search results and conversation context:
283
+ Web Search Results:
284
+ {context}
285
+ Conversation Context: {conv_context}
286
+ Current Question: {question}
287
+ Topics: {topics}
288
+ If the web search results don't contain relevant information, state that the information is not available in the search results.
289
+ Provide a summarized and direct answer to the question without mentioning the web search or these instructions.
290
+ Do not include any source information in your answer.
291
+ """
292
+
293
+ prompt_val = ChatPromptTemplate.from_template(prompt_template)
294
+ formatted_prompt = prompt_val.format(
295
+ context=context_str,
296
+ conv_context=chatbot.get_context(),
297
+ question=question,
298
+ topics=", ".join(topics)
299
+ )
300
+
301
+ full_response = generate_chunked_response(model, formatted_prompt)
302
+
303
+ answer_patterns = [
304
+ r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
305
+ r"Provide a concise and direct answer to the question:",
306
+ r"Answer:",
307
+ r"Provide a summarized and direct answer to the original question without mentioning the web search or these instructions:",
308
+ r"Do not include any source information in your answer."
309
+ ]
310
+
311
+ for pattern in answer_patterns:
312
+ match = re.split(pattern, full_response, flags=re.IGNORECASE)
313
+ if len(match) > 1:
314
+ answer = match[-1].strip()
315
+ break
316
+ else:
317
+ answer = full_response.strip()
318
+
319
+ all_answers.append(answer)
320
+ break
321
+
322
+ except Exception as e:
323
+ print(f"Error in ask_question (attempt {attempt + 1}): {e}")
324
+ if "Input validation error" in str(e) and attempt < max_attempts - 1:
325
+ print(f"Reducing context length for next attempt")
326
+ elif attempt == max_attempts - 1:
327
+ all_answers.append(f"I apologize, but I'm having trouble processing the query due to its length or complexity.")
328
 
329
  answer = "\n\n".join(all_answers)
330
+ sources = set(doc.metadata['source'] for doc in web_docs)
331
  sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
332
  answer += sources_section
333
 
 
389
 
390
  return "An unexpected error occurred. Please try again later."
391
 
392
+ # Gradio interface
393
  # Gradio interface
394
  with gr.Blocks() as demo:
395
+ gr.Markdown("# Context-Driven Conversational Chatbot")
396
 
397
  with gr.Row():
398
  file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
 
404
  with gr.Row():
405
  with gr.Column(scale=2):
406
  chatbot = gr.Chatbot(label="Conversation")
407
+ question_input = gr.Textbox(label="Ask a question")
408
  submit_button = gr.Button("Submit")
409
  with gr.Column(scale=1):
410
  temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
 
412
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
413
  web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
414
 
415
+ context_driven_chatbot = ContextDrivenChatbot()
416
 
417
  def chat(question, history, temperature, top_p, repetition_penalty, web_search):
418
+ answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, context_driven_chatbot)
419
  history.append((question, answer))
420
  return "", history
421