Eddie Pick commited on
Commit
d803be1
·
unverified ·
1 Parent(s): 2f49709

Fixes and updates

Browse files
Files changed (7) hide show
  1. copywriter.py +3 -0
  2. models.py +280 -0
  3. requirements.txt +2 -0
  4. search_agent.py +25 -26
  5. search_agent_ui.py +24 -16
  6. web_crawler.py +43 -11
  7. web_rag.py +79 -64
copywriter.py CHANGED
@@ -5,6 +5,7 @@ from langchain.prompts.chat import (
5
  ChatPromptTemplate
6
  )
7
  from langchain.prompts.prompt import PromptTemplate
 
8
 
9
 
10
  def get_comments_prompt(query, draft):
@@ -34,6 +35,7 @@ def get_comments_prompt(query, draft):
34
  )
35
  return [system_message, human_message]
36
 
 
37
  def generate_comments(chat_llm, query, draft, callbacks=[]):
38
  messages = get_comments_prompt(query, draft)
39
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
@@ -67,6 +69,7 @@ def get_final_text_prompt(query, draft, comments):
67
  return [system_message, human_message]
68
 
69
 
 
70
  def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
71
  messages = get_final_text_prompt(query, draft, comments)
72
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
 
5
  ChatPromptTemplate
6
  )
7
  from langchain.prompts.prompt import PromptTemplate
8
+ from langsmith import traceable
9
 
10
 
11
  def get_comments_prompt(query, draft):
 
35
  )
36
  return [system_message, human_message]
37
 
38
+ @traceable(run_type="llm", name="generate_comments")
39
  def generate_comments(chat_llm, query, draft, callbacks=[]):
40
  messages = get_comments_prompt(query, draft)
41
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
 
69
  return [system_message, human_message]
70
 
71
 
72
+ @traceable(run_type="llm", name="generate_final_text")
73
  def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
74
  messages = get_final_text_prompt(query, draft, comments)
75
  response = chat_llm.invoke(messages, config={"callbacks": callbacks})
models.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from langchain.schema import SystemMessage, HumanMessage
4
+ from langchain.prompts.chat import (
5
+ HumanMessagePromptTemplate,
6
+ SystemMessagePromptTemplate,
7
+ ChatPromptTemplate
8
+ )
9
+ from langchain.prompts.prompt import PromptTemplate
10
+ from langchain.retrievers.multi_query import MultiQueryRetriever
11
+
12
+ from langchain_aws import BedrockEmbeddings
13
+ from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
14
+ from langchain_cohere import ChatCohere
15
+ from langchain_fireworks.chat_models import ChatFireworks
16
+ from langchain_fireworks.embeddings import FireworksEmbeddings
17
+ from langchain_groq.chat_models import ChatGroq
18
+ from langchain_openai import ChatOpenAI
19
+ from langchain_openai.embeddings import OpenAIEmbeddings
20
+ from langchain_ollama.chat_models import ChatOllama
21
+ from langchain_ollama.embeddings import OllamaEmbeddings
22
+ from langchain_cohere.embeddings import CohereEmbeddings
23
+ from langchain_cohere.chat_models import ChatCohere
24
+ from langchain_openai.embeddings import OpenAIEmbeddings
25
+ from langchain_google_genai import ChatGoogleGenerativeAI
26
+ from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
27
+ from langchain_community.chat_models import ChatPerplexity
28
+ from langchain_together import ChatTogether
29
+ from langchain_together.embeddings import TogetherEmbeddings
30
+
31
+
32
+
33
+ def get_model(provider_model, temperature=0.0):
34
+ provider, model = (provider_model.split('/') + [None])[:2]
35
+ match provider:
36
+ case 'bedrock':
37
+ #credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
38
+ if model is None:
39
+ model = "anthropic.claude-3-sonnet-20240229-v1:0"
40
+ chat_llm = ChatBedrockConverse(
41
+ #credentials_profile_name=credentials_profile_name,
42
+ model=model,
43
+ temperature=temperature,
44
+ )
45
+ case 'cohere':
46
+ if model is None:
47
+ model = 'command-r-plus'
48
+ chat_llm = ChatCohere(model=model, temperature=temperature)
49
+ case 'fireworks':
50
+ if model is None:
51
+ model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
52
+ chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
53
+ case 'googlegenerativeai':
54
+ if model is None:
55
+ model = "gemini-1.5-pro"
56
+ chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
57
+ max_tokens=None, timeout=None, max_retries=2,)
58
+ case 'groq':
59
+ if model is None:
60
+ model = 'llama-3.1-8b-instant'
61
+ chat_llm = ChatGroq(model_name=model, temperature=temperature)
62
+ case 'ollama':
63
+ if model is None:
64
+ model = 'llama3.1'
65
+ chat_llm = ChatOllama(model=model, temperature=temperature)
66
+ case 'openai':
67
+ if model is None:
68
+ model = "gpt-4o-mini"
69
+ chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
70
+ case 'perplexity':
71
+ if model is None:
72
+ model = 'llama-3.1-sonar-small-128k-online'
73
+ chat_llm = ChatPerplexity(model=model, temperature=temperature)
74
+ case 'together':
75
+ if model is None:
76
+ model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
77
+ chat_llm = ChatTogether(model=model, temperature=temperature)
78
+ case _:
79
+ raise ValueError(f"Unknown LLM provider {provider}")
80
+
81
+ return chat_llm
82
+
83
+
84
+ def get_embedding_model(provider_embedding_model):
85
+ provider, model = (provider_embedding_model.split('/') + [None])[:2]
86
+ match provider:
87
+ case 'bedrock':
88
+ #credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
89
+ if model is None:
90
+ model = "cohere.embed-multilingual-v3"
91
+ embedding_model = BedrockEmbeddings(
92
+ model_id=model,
93
+ #credentials_profile_name=credentials_profile_name
94
+ )
95
+ case 'cohere':
96
+ if model is None:
97
+ model = "embed-english-light-v3.0"
98
+ embedding_model = CohereEmbeddings(model=model)
99
+ case 'fireworks':
100
+ if model is None:
101
+ model = 'nomic-ai/nomic-embed-text-v1.5'
102
+ embedding_model = FireworksEmbeddings(model=model)
103
+ case 'ollama':
104
+ if model is None:
105
+ model = 'nomic-embed-text:latest'
106
+ embedding_model = OllamaEmbeddings(model=model)
107
+ case 'openai':
108
+ if model is None:
109
+ model = "text-embedding-3-small"
110
+ embedding_model = OpenAIEmbeddings(model=model)
111
+ case 'googlegenerativeai':
112
+ if model is None:
113
+ model = "models/embedding-001"
114
+ embedding_model = GoogleGenerativeAIEmbeddings(model=model)
115
+ case 'groq':
116
+ embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
117
+ case 'perplexity':
118
+ raise ValueError(f"Cannot use Perplexity for embedding model")
119
+ case 'together':
120
+ if model is None:
121
+ model = 'BAAI/bge-base-en-v1.5'
122
+ embedding_model = TogetherEmbeddings(model=model)
123
+ case _:
124
+ raise ValueError(f"Unknown LLM provider {provider}")
125
+
126
+ return embedding_model
127
+
128
+
129
+ import unittest
130
+ from unittest.mock import patch
131
+ from models import get_embedding_model # Make sure this import is correct
132
+
133
+ class TestGetEmbeddingModel(unittest.TestCase):
134
+
135
+ @patch('models.BedrockEmbeddings')
136
+ def test_bedrock_embedding(self, mock_bedrock):
137
+ result = get_embedding_model('bedrock')
138
+ mock_bedrock.assert_called_once_with(model_id='cohere.embed-multilingual-v3')
139
+ self.assertEqual(result, mock_bedrock.return_value)
140
+
141
+ @patch('models.CohereEmbeddings')
142
+ def test_cohere_embedding(self, mock_cohere):
143
+ result = get_embedding_model('cohere')
144
+ mock_cohere.assert_called_once_with(model='embed-english-light-v3.0')
145
+ self.assertEqual(result, mock_cohere.return_value)
146
+
147
+ @patch('models.FireworksEmbeddings')
148
+ def test_fireworks_embedding(self, mock_fireworks):
149
+ result = get_embedding_model('fireworks')
150
+ mock_fireworks.assert_called_once_with(model='nomic-ai/nomic-embed-text-v1.5')
151
+ self.assertEqual(result, mock_fireworks.return_value)
152
+
153
+ @patch('models.OllamaEmbeddings')
154
+ def test_ollama_embedding(self, mock_ollama):
155
+ result = get_embedding_model('ollama')
156
+ mock_ollama.assert_called_once_with(model='nomic-embed-text:latest')
157
+ self.assertEqual(result, mock_ollama.return_value)
158
+
159
+ @patch('models.OpenAIEmbeddings')
160
+ def test_openai_embedding(self, mock_openai):
161
+ result = get_embedding_model('openai')
162
+ mock_openai.assert_called_once_with(model='text-embedding-3-small')
163
+ self.assertEqual(result, mock_openai.return_value)
164
+
165
+ @patch('models.GoogleGenerativeAIEmbeddings')
166
+ def test_google_embedding(self, mock_google):
167
+ result = get_embedding_model('googlegenerativeai')
168
+ mock_google.assert_called_once_with(model='models/embedding-001')
169
+ self.assertEqual(result, mock_google.return_value)
170
+
171
+ @patch('models.TogetherEmbeddings')
172
+ def test_together_embedding(self, mock_together):
173
+ result = get_embedding_model('together')
174
+ mock_together.assert_called_once_with(model='BAAI/bge-base-en-v1.5')
175
+ self.assertEqual(result, mock_together.return_value)
176
+
177
+ def test_invalid_provider(self):
178
+ with self.assertRaises(ValueError):
179
+ get_embedding_model('invalid_provider')
180
+
181
+ def test_groq_provider(self):
182
+ with self.assertRaises(ValueError):
183
+ get_embedding_model('groq')
184
+
185
+ def test_perplexity_provider(self):
186
+ with self.assertRaises(ValueError):
187
+ get_embedding_model('perplexity')
188
+
189
+
190
+ import unittest
191
+ from unittest.mock import patch
192
+ from models import get_model # Make sure this import is correct
193
+
194
+ class TestGetModel(unittest.TestCase):
195
+
196
+ @patch('models.ChatBedrockConverse')
197
+ def test_bedrock_model(self, mock_bedrock):
198
+ result = get_model('bedrock')
199
+ mock_bedrock.assert_called_once_with(
200
+ model="anthropic.claude-3-sonnet-20240229-v1:0",
201
+ temperature=0.0
202
+ )
203
+ self.assertEqual(result, mock_bedrock.return_value)
204
+
205
+ @patch('models.ChatCohere')
206
+ def test_cohere_model(self, mock_cohere):
207
+ result = get_model('cohere')
208
+ mock_cohere.assert_called_once_with(model='command-r-plus', temperature=0.0)
209
+ self.assertEqual(result, mock_cohere.return_value)
210
+
211
+ @patch('models.ChatFireworks')
212
+ def test_fireworks_model(self, mock_fireworks):
213
+ result = get_model('fireworks')
214
+ mock_fireworks.assert_called_once_with(
215
+ model_name='accounts/fireworks/models/llama-v3p1-8b-instruct',
216
+ temperature=0.0,
217
+ max_tokens=120000
218
+ )
219
+ self.assertEqual(result, mock_fireworks.return_value)
220
+
221
+ @patch('models.ChatGoogleGenerativeAI')
222
+ def test_google_model(self, mock_google):
223
+ result = get_model('googlegenerativeai')
224
+ mock_google.assert_called_once_with(
225
+ model="gemini-1.5-pro",
226
+ temperature=0.0,
227
+ max_tokens=None,
228
+ timeout=None,
229
+ max_retries=2
230
+ )
231
+ self.assertEqual(result, mock_google.return_value)
232
+
233
+ @patch('models.ChatGroq')
234
+ def test_groq_model(self, mock_groq):
235
+ result = get_model('groq')
236
+ mock_groq.assert_called_once_with(model_name='llama-3.1-8b-instant', temperature=0.0)
237
+ self.assertEqual(result, mock_groq.return_value)
238
+
239
+ @patch('models.ChatOllama')
240
+ def test_ollama_model(self, mock_ollama):
241
+ result = get_model('ollama')
242
+ mock_ollama.assert_called_once_with(model='llama3.1', temperature=0.0)
243
+ self.assertEqual(result, mock_ollama.return_value)
244
+
245
+ @patch('models.ChatOpenAI')
246
+ def test_openai_model(self, mock_openai):
247
+ result = get_model('openai')
248
+ mock_openai.assert_called_once_with(model_name='gpt-4o-mini', temperature=0.0)
249
+ self.assertEqual(result, mock_openai.return_value)
250
+
251
+ @patch('models.ChatPerplexity')
252
+ def test_perplexity_model(self, mock_perplexity):
253
+ result = get_model('perplexity')
254
+ mock_perplexity.assert_called_once_with(model='llama-3.1-sonar-small-128k-online', temperature=0.0)
255
+ self.assertEqual(result, mock_perplexity.return_value)
256
+
257
+ @patch('models.ChatTogether')
258
+ def test_together_model(self, mock_together):
259
+ result = get_model('together')
260
+ mock_together.assert_called_once_with(model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', temperature=0.0)
261
+ self.assertEqual(result, mock_together.return_value)
262
+
263
+ def test_invalid_provider(self):
264
+ with self.assertRaises(ValueError):
265
+ get_model('invalid_provider')
266
+
267
+ def test_custom_temperature(self):
268
+ with patch('models.ChatOpenAI') as mock_openai:
269
+ result = get_model('openai', temperature=0.5)
270
+ mock_openai.assert_called_once_with(model_name='gpt-4o-mini', temperature=0.5)
271
+ self.assertEqual(result, mock_openai.return_value)
272
+
273
+ def test_custom_model(self):
274
+ with patch('models.ChatOpenAI') as mock_openai:
275
+ result = get_model('openai/gpt-4')
276
+ mock_openai.assert_called_once_with(model_name='gpt-4', temperature=0.0)
277
+ self.assertEqual(result, mock_openai.return_value)
278
+
279
+ if __name__ == '__main__':
280
+ unittest.main()
requirements.txt CHANGED
@@ -18,6 +18,8 @@ langchain_experimental
18
  langchain_openai
19
  langchain-ollama
20
  langchain_groq
 
 
21
  langsmith
22
  schema
23
  streamlit
 
18
  langchain_openai
19
  langchain-ollama
20
  langchain_groq
21
+ langchain-google-genai
22
+ langchain-together
23
  langsmith
24
  schema
25
  streamlit
search_agent.py CHANGED
@@ -5,10 +5,12 @@ Usage:
5
  [--domain=domain]
6
  [--provider=provider]
7
  [--model=model]
 
8
  [--temperature=temp]
9
  [--copywrite]
10
  [--max_pages=num]
11
  [--max_extracts=num]
 
12
  [--output=text]
13
  SEARCH_QUERY
14
  search_agent.py --version
@@ -19,10 +21,11 @@ Options:
19
  -c --copywrite First produce a draft, review it and rewrite for a final text
20
  -d domain --domain=domain Limit search to a specific domain
21
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
22
- -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
23
- -m model --model=model Use a specific model
24
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
25
  -e num --max_extracts=num Max number of page extract to consider [default: 5]
 
26
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
27
 
28
  """
@@ -35,7 +38,7 @@ import dotenv
35
 
36
  from langchain.callbacks import LangChainTracer
37
 
38
- from langsmith import Client
39
 
40
  from rich.console import Console
41
  from rich.markdown import Markdown
@@ -43,6 +46,7 @@ from rich.markdown import Markdown
43
  import web_rag as wr
44
  import web_crawler as wc
45
  import copywriter as cw
 
46
 
47
  console = Console()
48
  dotenv.load_dotenv()
@@ -70,34 +74,24 @@ if os.getenv("LANGCHAIN_API_KEY"):
70
  callbacks.append(
71
  LangChainTracer(client=Client())
72
  )
73
-
74
- if __name__ == '__main__':
75
- arguments = docopt(__doc__, version='Search Agent 0.1')
76
-
77
- #schema = Schema({
78
- # '--max_pages': Use(int, error='--max_pages must be an integer'),
79
- # '--temperature': Use(float, error='--temperature must be an float'),
80
- #})
81
-
82
- #try:
83
- # arguments = schema.validate(arguments)
84
- #except SchemaError as e:
85
- # exit(e)
86
-
87
  copywrite_mode = arguments["--copywrite"]
88
- provider = arguments["--provider"]
89
  model = arguments["--model"]
 
90
  temperature = float(arguments["--temperature"])
91
  domain=arguments["--domain"]
92
- max_pages=arguments["--max_pages"]
93
  max_extract=int(arguments["--max_extracts"])
94
  output=arguments["--output"]
 
95
  query = arguments["SEARCH_QUERY"]
96
 
97
- chat, embedding_model = wr.get_models(provider, model, temperature)
 
98
 
99
  with console.status(f"[bold green]Optimizing query for search: {query}"):
100
- optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
101
  if len(optimize_search_query) < 3:
102
  optimize_search_query = query
103
  console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
@@ -111,16 +105,16 @@ if __name__ == '__main__':
111
  with console.status(
112
  f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
113
  ):
114
- contents = wc.get_links_contents(sources, get_selenium_driver)
115
  console.log(f"Managed to extract content from {len(contents)} sources")
116
 
117
  with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
118
  vector_store = wc.vectorize(contents, embedding_model)
119
 
120
  with console.status("[bold green]Writing content", spinner='dots8Bit'):
121
- draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract, callbacks=callbacks)
122
 
123
- console.rule(f"[bold green]Response from {provider}")
124
  if output == "text":
125
  console.print(draft)
126
  else:
@@ -129,7 +123,7 @@ if __name__ == '__main__':
129
 
130
  if(copywrite_mode):
131
  with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
132
- comments = cw.generate_comments(chat, query, draft, callbacks=callbacks)
133
 
134
  console.rule("[bold green]Response from reviewer")
135
  if output == "text":
@@ -139,7 +133,7 @@ if __name__ == '__main__':
139
  console.rule("[bold green]")
140
 
141
  with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
142
- final_text = cw.generate_final_text(chat, query, draft, comments, callbacks=callbacks)
143
 
144
  console.rule("[bold green]Final text")
145
  if output == "text":
@@ -147,3 +141,8 @@ if __name__ == '__main__':
147
  else:
148
  console.print(Markdown(final_text))
149
  console.rule("[bold green]")
 
 
 
 
 
 
5
  [--domain=domain]
6
  [--provider=provider]
7
  [--model=model]
8
+ [--embedding_model=model]
9
  [--temperature=temp]
10
  [--copywrite]
11
  [--max_pages=num]
12
  [--max_extracts=num]
13
+ [--use_selenium]
14
  [--output=text]
15
  SEARCH_QUERY
16
  search_agent.py --version
 
21
  -c --copywrite First produce a draft, review it and rewrite for a final text
22
  -d domain --domain=domain Limit search to a specific domain
23
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
24
+ -m model --model=model Use a specific model [default: openai/gpt-4o-mini]
25
+ -e model --embedding_model=model Use a specific embedding model [default: openai/text-embedding-3-small]
26
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
27
  -e num --max_extracts=num Max number of page extract to consider [default: 5]
28
+ -s --use_selenium Use selenium to fetch content from the web [default: False]
29
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
30
 
31
  """
 
38
 
39
  from langchain.callbacks import LangChainTracer
40
 
41
+ from langsmith import Client, traceable
42
 
43
  from rich.console import Console
44
  from rich.markdown import Markdown
 
46
  import web_rag as wr
47
  import web_crawler as wc
48
  import copywriter as cw
49
+ import models as md
50
 
51
  console = Console()
52
  dotenv.load_dotenv()
 
74
  callbacks.append(
75
  LangChainTracer(client=Client())
76
  )
77
+ @traceable(run_type="tool", name="search_agent")
78
+ def main(arguments):
 
 
 
 
 
 
 
 
 
 
 
 
79
  copywrite_mode = arguments["--copywrite"]
 
80
  model = arguments["--model"]
81
+ embedding_model = arguments["--embedding_model"]
82
  temperature = float(arguments["--temperature"])
83
  domain=arguments["--domain"]
84
+ max_pages=int(arguments["--max_pages"])
85
  max_extract=int(arguments["--max_extracts"])
86
  output=arguments["--output"]
87
+ use_selenium=arguments["--use_selenium"]
88
  query = arguments["SEARCH_QUERY"]
89
 
90
+ chat = md.get_model(model, temperature)
91
+ embedding_model = md.get_embedding_model(embedding_model)
92
 
93
  with console.status(f"[bold green]Optimizing query for search: {query}"):
94
+ optimize_search_query = wr.optimize_search_query(chat, query)
95
  if len(optimize_search_query) < 3:
96
  optimize_search_query = query
97
  console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
 
105
  with console.status(
106
  f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
107
  ):
108
+ contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
109
  console.log(f"Managed to extract content from {len(contents)} sources")
110
 
111
  with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
112
  vector_store = wc.vectorize(contents, embedding_model)
113
 
114
  with console.status("[bold green]Writing content", spinner='dots8Bit'):
115
+ draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract)
116
 
117
+ console.rule(f"[bold green]Response")
118
  if output == "text":
119
  console.print(draft)
120
  else:
 
123
 
124
  if(copywrite_mode):
125
  with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
126
+ comments = cw.generate_comments(chat, query, draft)
127
 
128
  console.rule("[bold green]Response from reviewer")
129
  if output == "text":
 
133
  console.rule("[bold green]")
134
 
135
  with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
136
+ final_text = cw.generate_final_text(chat, query, draft, comments)
137
 
138
  console.rule("[bold green]Final text")
139
  if output == "text":
 
141
  else:
142
  console.print(Markdown(final_text))
143
  console.rule("[bold green]")
144
+
145
+ if __name__ == '__main__':
146
+ arguments = docopt(__doc__, version='Search Agent 0.1')
147
+ main(arguments)
148
+
search_agent_ui.py CHANGED
@@ -11,7 +11,7 @@ from langsmith.client import Client
11
  import web_rag as wr
12
  import web_crawler as wc
13
  import copywriter as cw
14
-
15
  dotenv.load_dotenv()
16
 
17
  ls_tracer = LangChainTracer(
@@ -54,26 +54,26 @@ def create_links_markdown(sources_list):
54
  st.set_page_config(layout="wide")
55
  st.title("🔍 Simple Search Agent 💬")
56
 
57
- if "providers" not in st.session_state:
58
- providers = []
59
  if os.getenv("FIREWORKS_API_KEY"):
60
- providers.append("fireworks")
61
  if os.getenv("COHERE_API_KEY"):
62
- providers.append("cohere")
63
  if os.getenv("OPENAI_API_KEY"):
64
- providers.append("openai")
65
  if os.getenv("GROQ_API_KEY"):
66
- providers.append("groq")
67
  if os.getenv("OLLAMA_API_KEY"):
68
- providers.append("ollama")
69
  if os.getenv("CREDENTIALS_PROFILE_NAME"):
70
- providers.append("bedrock")
71
- st.session_state["providers"] = providers
72
 
73
  with st.sidebar.expander("Options", expanded=False):
74
- model_provider = st.selectbox("Model provider 🧠", st.session_state["providers"])
75
  temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
76
- max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrive from the internet")
77
  top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
78
  reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
79
 
@@ -108,7 +108,8 @@ if prompt := st.chat_input("Enter you instructions..." ):
108
  st.chat_message("user").write(prompt)
109
  st.session_state.messages.append({"role": "user", "content": prompt})
110
 
111
- chat, embedding_model = wr.get_models(model_provider, temperature=temperature)
 
112
 
113
  with st.status("Thinking", expanded=True):
114
  st.write("I first need to do some research")
@@ -120,7 +121,7 @@ if prompt := st.chat_input("Enter you instructions..." ):
120
  links_md.markdown(create_links_markdown(sources))
121
 
122
  st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
123
- contents = wc.get_links_contents(sources)
124
 
125
  st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
126
  vector_store = wc.vectorize(contents, embedding_model=embedding_model)
@@ -147,8 +148,15 @@ if prompt := st.chat_input("Enter you instructions..." ):
147
 
148
  with st.chat_message("assistant"):
149
  st_cb = StreamHandler(st.empty())
150
- result = chat.invoke(rag_prompt, stream=True, config={ "callbacks": [st_cb, ls_tracer]})
151
- response = result.content.strip()
 
 
 
 
 
 
 
152
  message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
153
  st.session_state.messages.append({"role": "assistant", "content": response})
154
 
 
11
  import web_rag as wr
12
  import web_crawler as wc
13
  import copywriter as cw
14
+ import models as md
15
  dotenv.load_dotenv()
16
 
17
  ls_tracer = LangChainTracer(
 
54
  st.set_page_config(layout="wide")
55
  st.title("🔍 Simple Search Agent 💬")
56
 
57
+ if "models" not in st.session_state:
58
+ models = []
59
  if os.getenv("FIREWORKS_API_KEY"):
60
+ models.append("fireworks")
61
  if os.getenv("COHERE_API_KEY"):
62
+ models.append("cohere")
63
  if os.getenv("OPENAI_API_KEY"):
64
+ models.append("openai")
65
  if os.getenv("GROQ_API_KEY"):
66
+ models.append("groq")
67
  if os.getenv("OLLAMA_API_KEY"):
68
+ models.append("ollama")
69
  if os.getenv("CREDENTIALS_PROFILE_NAME"):
70
+ models.append("bedrock")
71
+ st.session_state["models"] = models
72
 
73
  with st.sidebar.expander("Options", expanded=False):
74
+ model_provider = st.selectbox("Model provider 🧠", st.session_state["models"])
75
  temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
76
+ max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to retrive from the internet")
77
  top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
78
  reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
79
 
 
108
  st.chat_message("user").write(prompt)
109
  st.session_state.messages.append({"role": "user", "content": prompt})
110
 
111
+ chat = md.get_model(model_provider, temperature)
112
+ embedding_model = md.get_embedding_model(model_provider)
113
 
114
  with st.status("Thinking", expanded=True):
115
  st.write("I first need to do some research")
 
121
  links_md.markdown(create_links_markdown(sources))
122
 
123
  st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
124
+ contents = wc.get_links_contents(sources, use_selenium=False)
125
 
126
  st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
127
  vector_store = wc.vectorize(contents, embedding_model=embedding_model)
 
148
 
149
  with st.chat_message("assistant"):
150
  st_cb = StreamHandler(st.empty())
151
+ if hasattr(chat, 'stream'):
152
+ response = ""
153
+ for chunk in chat.stream(rag_prompt, config={"callbacks": [st_cb, ls_tracer]}):
154
+ response += chunk.content
155
+ else:
156
+ result = chat.invoke(rag_prompt, config={"callbacks": [st_cb, ls_tracer]})
157
+ response = result.content
158
+
159
+ response = response.strip()
160
  message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
161
  st.session_state.messages.append({"role": "assistant", "content": response})
162
 
web_crawler.py CHANGED
@@ -8,12 +8,14 @@ from trafilatura import extract
8
  from selenium.common.exceptions import TimeoutException
9
  from langchain_core.documents.base import Document
10
  from langchain_experimental.text_splitter import SemanticChunker
 
11
  from langchain_openai import OpenAIEmbeddings
12
  from langchain_community.vectorstores.faiss import FAISS
13
-
14
  import requests
15
  import pdfplumber
16
 
 
17
  def get_sources(query, max_pages=10, domain=None):
18
  search_query = query
19
  if domain:
@@ -78,8 +80,7 @@ def fetch_with_timeout(url, timeout=8):
78
 
79
  def process_source(source):
80
  url = source['link']
81
- #console.log(f"Processing {url}")
82
- response = fetch_with_timeout(url, 8)
83
  if response:
84
  content_type = response.headers.get('Content-Type')
85
  if content_type:
@@ -107,12 +108,13 @@ def process_source(source):
107
  return {**source, 'page_content': source['snippet']}
108
  return {**source, 'page_content': None}
109
 
110
- def get_links_contents(sources, get_driver_func=None):
 
111
  with ThreadPoolExecutor() as executor:
112
  results = list(executor.map(process_source, sources))
113
 
114
- if get_driver_func is None:
115
- return [result for result in results if result is not None]
116
 
117
  for result in results:
118
  if result['page_content'] is None:
@@ -125,19 +127,49 @@ def get_links_contents(sources, get_driver_func=None):
125
  result['page_content'] = main_content
126
  return results
127
 
 
128
  def vectorize(contents, embedding_model):
129
  documents = []
 
130
  for content in contents:
131
  try:
132
  page_content = content['page_content']
133
- if page_content: # Sometimes Selenium is not fetching properly
134
  metadata = {'title': content['title'], 'source': content['link']}
135
  doc = Document(page_content=content['page_content'], metadata=metadata)
136
  documents.append(doc)
 
137
  except Exception as e:
138
  print(f"[gray]Error processing content for {content['link']}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  semantic_chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
140
- docs = semantic_chunker.split_documents(documents)
141
- embeddings = OpenAIEmbeddings()
142
- store = FAISS.from_documents(docs, embeddings)
143
- return store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from selenium.common.exceptions import TimeoutException
9
  from langchain_core.documents.base import Document
10
  from langchain_experimental.text_splitter import SemanticChunker
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
  from langchain_openai import OpenAIEmbeddings
13
  from langchain_community.vectorstores.faiss import FAISS
14
+ from langsmith import traceable
15
  import requests
16
  import pdfplumber
17
 
18
+ @traceable(run_type="tool", name="get_sources")
19
  def get_sources(query, max_pages=10, domain=None):
20
  search_query = query
21
  if domain:
 
80
 
81
  def process_source(source):
82
  url = source['link']
83
+ response = fetch_with_timeout(url, 2)
 
84
  if response:
85
  content_type = response.headers.get('Content-Type')
86
  if content_type:
 
108
  return {**source, 'page_content': source['snippet']}
109
  return {**source, 'page_content': None}
110
 
111
+ @traceable(run_type="tool", name="get_links_contents")
112
+ def get_links_contents(sources, get_driver_func=None, use_selenium=False):
113
  with ThreadPoolExecutor() as executor:
114
  results = list(executor.map(process_source, sources))
115
 
116
+ if get_driver_func is None or not use_selenium:
117
+ return [result for result in results if result is not None and result['page_content']]
118
 
119
  for result in results:
120
  if result['page_content'] is None:
 
127
  result['page_content'] = main_content
128
  return results
129
 
130
+ @traceable(run_type="embedding")
131
  def vectorize(contents, embedding_model):
132
  documents = []
133
+ total_content_length = 0
134
  for content in contents:
135
  try:
136
  page_content = content['page_content']
137
+ if page_content:
138
  metadata = {'title': content['title'], 'source': content['link']}
139
  doc = Document(page_content=content['page_content'], metadata=metadata)
140
  documents.append(doc)
141
+ total_content_length += len(page_content)
142
  except Exception as e:
143
  print(f"[gray]Error processing content for {content['link']}: {e}")
144
+
145
+ # Define a threshold for when to use pre-splitting (e.g., 1 million characters)
146
+ pre_split_threshold = 1_000_000
147
+
148
+ if total_content_length > pre_split_threshold:
149
+ # Use pre-splitting for large datasets
150
+ pre_splitter = RecursiveCharacterTextSplitter(
151
+ chunk_size=2000,
152
+ chunk_overlap=200,
153
+ length_function=len,
154
+ )
155
+ documents = pre_splitter.split_documents(documents)
156
+
157
  semantic_chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
158
+
159
+ vector_store = None
160
+ batch_size = 200 # Adjust this value if needed
161
+
162
+ for i in range(0, len(documents), batch_size):
163
+ batch = documents[i:i+batch_size]
164
+
165
+ # Split each document in the batch using SemanticChunker
166
+ chunked_docs = []
167
+ for doc in batch:
168
+ chunked_docs.extend(semantic_chunker.split_documents([doc]))
169
+
170
+ if vector_store is None:
171
+ vector_store = FAISS.from_documents(chunked_docs, embedding_model)
172
+ else:
173
+ vector_store.add_documents(chunked_docs)
174
+
175
+ return vector_store
web_rag.py CHANGED
@@ -36,53 +36,7 @@ from langchain_groq.chat_models import ChatGroq
36
  from langchain_openai import ChatOpenAI
37
  from langchain_openai.embeddings import OpenAIEmbeddings
38
  from langchain_ollama.chat_models import ChatOllama
39
-
40
- def get_models(provider, model=None, temperature=0.0):
41
- match provider:
42
- case 'bedrock':
43
- credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
44
- if model is None:
45
- model = "anthropic.claude-3-sonnet-20240229-v1:0"
46
- chat_llm = ChatBedrockConverse(
47
- credentials_profile_name=credentials_profile_name,
48
- model=model,
49
- temperature=temperature,
50
- )
51
- embedding_model = BedrockEmbeddings(
52
- model_id='cohere.embed-multilingual-v3',
53
- credentials_profile_name=credentials_profile_name
54
- )
55
- embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
56
- case 'openai':
57
- if model is None:
58
- model = "gpt-4o-mini"
59
- chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
60
- embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
61
- case 'groq':
62
- if model is None:
63
- model = 'mixtral-8x7b-32768'
64
- chat_llm = ChatGroq(model_name=model, temperature=temperature)
65
- embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
66
- case 'ollama':
67
- if model is None:
68
- model = 'llama3.1'
69
- chat_llm = ChatOllama(model=model, temperature=temperature)
70
- embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
71
- case 'cohere':
72
- if model is None:
73
- model = 'command-r-plus'
74
- chat_llm = ChatCohere(model=model, temperature=temperature)
75
- #embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
76
- embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
77
- case 'fireworks':
78
- if model is None:
79
- model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
80
- chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
81
- embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
82
- case _:
83
- raise ValueError(f"Unknown LLM provider {provider}")
84
-
85
- return chat_llm, embedding_model
86
 
87
 
88
  def get_optimized_search_messages(query):
@@ -97,12 +51,13 @@ def get_optimized_search_messages(query):
97
  """
98
  system_message = SystemMessage(
99
  content="""
100
- I want you to act as a prompt optimizer for web search.
101
- I will provide you with a chat prompt, and your goal is to optimize it into a search string that will yield the most relevant and useful information from a search engine like Google.
 
102
  To optimize the prompt:
103
  - Identify the key information being requested
 
104
  - Arrange the keywords into a concise search string
105
- - Keep it short, around 1 to 5 words total
106
  - Put the most important keywords first
107
 
108
  Some tips and things to be sure to remove:
@@ -111,7 +66,7 @@ def get_optimized_search_messages(query):
111
  - Remove lenght instruction (example: essay, article, letter, blog, post, blogpost, etc)
112
  - Remove style instructions (exmaple: "in the style of", engaging, short, long)
113
  - Remove lenght instruction (example: essay, article, letter, etc)
114
-
115
  You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
116
 
117
  Example:
@@ -119,19 +74,16 @@ def get_optimized_search_messages(query):
119
  chocolate chip cookies recipe from scratch**
120
  Example:
121
  Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
122
- Marie Curie timeline**
123
  Example:
124
  Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
125
  geopolitics nato russia**
126
  Example:
127
  Question: Write an engaging LinkedIn post about Andrew Ng
128
- Andrew Ng**
129
  Example:
130
  Question: Write a short article about the solar system in the style of Carl Sagan
131
  solar system**
132
- Example:
133
- Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
134
- Kubernetes decision**
135
  Example:
136
  Question: Biography of Napoleon. Include a table with the major events.
137
  napoleon biography events**
@@ -155,12 +107,73 @@ def get_optimized_search_messages(query):
155
  return [system_message, human_message]
156
 
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def optimize_search_query(chat_llm, query, callbacks=[]):
159
  messages = get_optimized_search_messages(query)
160
- response = chat_llm.invoke(messages, config={"callbacks": callbacks})
161
- optimized_search_query = response.content
162
- return optimized_search_query.strip('"').split("**", 1)[0].strip()
163
-
 
 
 
 
 
 
 
 
 
 
164
 
165
  def get_rag_prompt_template():
166
  """
@@ -185,8 +198,9 @@ def get_rag_prompt_template():
185
  - Format your answer in Markdown, using heading levels 2-3 as needed
186
  - Include a "References" section at the end with the full citations and link for each source you used
187
 
188
- If you cannot answer the question with confidence just say: "I'm not sure about the answer to be honest"
189
- If the provided context is not relevant to the question, just say: "The context provided is not relevant to the question"
 
190
  """
191
  )
192
  )
@@ -245,7 +259,7 @@ def get_context_size(chat_llm):
245
  if isinstance(chat_llm, ChatOllama):
246
  return 120000
247
  if isinstance(chat_llm, ChatCohere):
248
- return 128000
249
  if isinstance(chat_llm, ChatBedrockConverse):
250
  if chat_llm.model_id.startswith("meta.llama3-1"):
251
  return 128000
@@ -259,7 +273,7 @@ def get_context_size(chat_llm):
259
  return 32000
260
  return 4096
261
 
262
-
263
  def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
264
  done = False
265
  while not done:
@@ -275,6 +289,7 @@ def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10,
275
 
276
  return prompt
277
 
 
278
  def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
279
  prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
280
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
 
36
  from langchain_openai import ChatOpenAI
37
  from langchain_openai.embeddings import OpenAIEmbeddings
38
  from langchain_ollama.chat_models import ChatOllama
39
+ from langsmith import traceable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def get_optimized_search_messages(query):
 
51
  """
52
  system_message = SystemMessage(
53
  content="""
54
+ You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
55
+ The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
56
+
57
  To optimize the prompt:
58
  - Identify the key information being requested
59
+ - Consider any implicit information or context that might be useful for the search.
60
  - Arrange the keywords into a concise search string
 
61
  - Put the most important keywords first
62
 
63
  Some tips and things to be sure to remove:
 
66
  - Remove lenght instruction (example: essay, article, letter, blog, post, blogpost, etc)
67
  - Remove style instructions (exmaple: "in the style of", engaging, short, long)
68
  - Remove lenght instruction (example: essay, article, letter, etc)
69
+
70
  You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
71
 
72
  Example:
 
74
  chocolate chip cookies recipe from scratch**
75
  Example:
76
  Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
77
+ "Marie Curie" timeline**
78
  Example:
79
  Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
80
  geopolitics nato russia**
81
  Example:
82
  Question: Write an engaging LinkedIn post about Andrew Ng
83
+ "Andrew Ng"**
84
  Example:
85
  Question: Write a short article about the solar system in the style of Carl Sagan
86
  solar system**
 
 
 
87
  Example:
88
  Question: Biography of Napoleon. Include a table with the major events.
89
  napoleon biography events**
 
107
  return [system_message, human_message]
108
 
109
 
110
+
111
+ def get_optimized_search_messages2(query):
112
+ """
113
+ Generate optimized search messages for a given query.
114
+
115
+ Args:
116
+ query (str): The user's query.
117
+
118
+ Returns:
119
+ list: A list containing the system message and human message for optimized search.
120
+ """
121
+ system_message = SystemMessage(
122
+ content="""
123
+ You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
124
+
125
+ The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
126
+
127
+ Here are some key principles for creating effective search queries:
128
+ 1. Use specific and relevant keywords
129
+ 2. Remove unnecessary words (articles, prepositions, etc.)
130
+ 3. Utilize quotation marks for exact phrases
131
+ 4. Employ Boolean operators (AND, OR, NOT) when appropriate
132
+ 5. Include synonyms or related terms to broaden the search
133
+
134
+ I will provide you with a chat prompt or question. Your task is to optimize this into an effective search string.
135
+
136
+ Process the input as follows:
137
+ 1. Analyze the Question to identify the main topic and key concepts.
138
+ 2. Extract the most relevant keywords and phrases.
139
+ 3. Consider any implicit information or context that might be useful for the search.
140
+
141
+ Then, optimize the search string by:
142
+ 1. Removing filler words and unnecessary language
143
+ 2. Rearranging keywords in a logical order
144
+ 3. Adding quotation marks around exact phrases if applicable
145
+ 4. Including relevant synonyms or related terms (in parentheses) to broaden the search
146
+ 5. Using Boolean operators if needed to refine the search
147
+
148
+ You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the optimized search query
149
+ """
150
+ )
151
+ human_message = HumanMessage(
152
+ content=f"""
153
+ Question: {query}
154
+
155
+ """
156
+ )
157
+ return [system_message, human_message]
158
+
159
+
160
+ @traceable(run_type="llm", name="optimize_search_query")
161
  def optimize_search_query(chat_llm, query, callbacks=[]):
162
  messages = get_optimized_search_messages(query)
163
+ response = chat_llm.invoke(messages)
164
+ optimized_search_query = response.content.strip()
165
+
166
+ # Split by '**' and take the first part, then strip whitespace
167
+ optimized_search_query = optimized_search_query.split("**", 1)[0].strip()
168
+
169
+ # Remove surrounding quotes if present
170
+ optimized_search_query = optimized_search_query.strip('"')
171
+
172
+ # If the result is empty, fall back to the original query
173
+ if not optimized_search_query:
174
+ optimized_search_query = query
175
+
176
+ return optimized_search_query
177
 
178
  def get_rag_prompt_template():
179
  """
 
198
  - Format your answer in Markdown, using heading levels 2-3 as needed
199
  - Include a "References" section at the end with the full citations and link for each source you used
200
 
201
+ If the provided context is not relevant to the question, say it and answer with your internal knowledge.
202
+ If you cannot answer the question using either the extracts or your internal knowledge, state that you don't have enough information to provide an accurate answer.
203
+ If the information in the provided context is in contradiction with your internal knowledge, answer but warn the user about the contradiction.
204
  """
205
  )
206
  )
 
259
  if isinstance(chat_llm, ChatOllama):
260
  return 120000
261
  if isinstance(chat_llm, ChatCohere):
262
+ return 120000
263
  if isinstance(chat_llm, ChatBedrockConverse):
264
  if chat_llm.model_id.startswith("meta.llama3-1"):
265
  return 128000
 
273
  return 32000
274
  return 4096
275
 
276
+ @traceable(run_type="retriever")
277
  def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
278
  done = False
279
  while not done:
 
289
 
290
  return prompt
291
 
292
+ @traceable(run_type="llm", name="query_rag")
293
  def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
294
  prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
295
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})