Spaces:
Sleeping
Sleeping
Eddie Pick
commited on
Fixes and updates
Browse files- copywriter.py +3 -0
- models.py +280 -0
- requirements.txt +2 -0
- search_agent.py +25 -26
- search_agent_ui.py +24 -16
- web_crawler.py +43 -11
- web_rag.py +79 -64
copywriter.py
CHANGED
@@ -5,6 +5,7 @@ from langchain.prompts.chat import (
|
|
5 |
ChatPromptTemplate
|
6 |
)
|
7 |
from langchain.prompts.prompt import PromptTemplate
|
|
|
8 |
|
9 |
|
10 |
def get_comments_prompt(query, draft):
|
@@ -34,6 +35,7 @@ def get_comments_prompt(query, draft):
|
|
34 |
)
|
35 |
return [system_message, human_message]
|
36 |
|
|
|
37 |
def generate_comments(chat_llm, query, draft, callbacks=[]):
|
38 |
messages = get_comments_prompt(query, draft)
|
39 |
response = chat_llm.invoke(messages, config={"callbacks": callbacks})
|
@@ -67,6 +69,7 @@ def get_final_text_prompt(query, draft, comments):
|
|
67 |
return [system_message, human_message]
|
68 |
|
69 |
|
|
|
70 |
def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
|
71 |
messages = get_final_text_prompt(query, draft, comments)
|
72 |
response = chat_llm.invoke(messages, config={"callbacks": callbacks})
|
|
|
5 |
ChatPromptTemplate
|
6 |
)
|
7 |
from langchain.prompts.prompt import PromptTemplate
|
8 |
+
from langsmith import traceable
|
9 |
|
10 |
|
11 |
def get_comments_prompt(query, draft):
|
|
|
35 |
)
|
36 |
return [system_message, human_message]
|
37 |
|
38 |
+
@traceable(run_type="llm", name="generate_comments")
|
39 |
def generate_comments(chat_llm, query, draft, callbacks=[]):
|
40 |
messages = get_comments_prompt(query, draft)
|
41 |
response = chat_llm.invoke(messages, config={"callbacks": callbacks})
|
|
|
69 |
return [system_message, human_message]
|
70 |
|
71 |
|
72 |
+
@traceable(run_type="llm", name="generate_final_text")
|
73 |
def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
|
74 |
messages = get_final_text_prompt(query, draft, comments)
|
75 |
response = chat_llm.invoke(messages, config={"callbacks": callbacks})
|
models.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from langchain.schema import SystemMessage, HumanMessage
|
4 |
+
from langchain.prompts.chat import (
|
5 |
+
HumanMessagePromptTemplate,
|
6 |
+
SystemMessagePromptTemplate,
|
7 |
+
ChatPromptTemplate
|
8 |
+
)
|
9 |
+
from langchain.prompts.prompt import PromptTemplate
|
10 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
11 |
+
|
12 |
+
from langchain_aws import BedrockEmbeddings
|
13 |
+
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
|
14 |
+
from langchain_cohere import ChatCohere
|
15 |
+
from langchain_fireworks.chat_models import ChatFireworks
|
16 |
+
from langchain_fireworks.embeddings import FireworksEmbeddings
|
17 |
+
from langchain_groq.chat_models import ChatGroq
|
18 |
+
from langchain_openai import ChatOpenAI
|
19 |
+
from langchain_openai.embeddings import OpenAIEmbeddings
|
20 |
+
from langchain_ollama.chat_models import ChatOllama
|
21 |
+
from langchain_ollama.embeddings import OllamaEmbeddings
|
22 |
+
from langchain_cohere.embeddings import CohereEmbeddings
|
23 |
+
from langchain_cohere.chat_models import ChatCohere
|
24 |
+
from langchain_openai.embeddings import OpenAIEmbeddings
|
25 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
26 |
+
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
27 |
+
from langchain_community.chat_models import ChatPerplexity
|
28 |
+
from langchain_together import ChatTogether
|
29 |
+
from langchain_together.embeddings import TogetherEmbeddings
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def get_model(provider_model, temperature=0.0):
|
34 |
+
provider, model = (provider_model.split('/') + [None])[:2]
|
35 |
+
match provider:
|
36 |
+
case 'bedrock':
|
37 |
+
#credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
|
38 |
+
if model is None:
|
39 |
+
model = "anthropic.claude-3-sonnet-20240229-v1:0"
|
40 |
+
chat_llm = ChatBedrockConverse(
|
41 |
+
#credentials_profile_name=credentials_profile_name,
|
42 |
+
model=model,
|
43 |
+
temperature=temperature,
|
44 |
+
)
|
45 |
+
case 'cohere':
|
46 |
+
if model is None:
|
47 |
+
model = 'command-r-plus'
|
48 |
+
chat_llm = ChatCohere(model=model, temperature=temperature)
|
49 |
+
case 'fireworks':
|
50 |
+
if model is None:
|
51 |
+
model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
|
52 |
+
chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
|
53 |
+
case 'googlegenerativeai':
|
54 |
+
if model is None:
|
55 |
+
model = "gemini-1.5-pro"
|
56 |
+
chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
|
57 |
+
max_tokens=None, timeout=None, max_retries=2,)
|
58 |
+
case 'groq':
|
59 |
+
if model is None:
|
60 |
+
model = 'llama-3.1-8b-instant'
|
61 |
+
chat_llm = ChatGroq(model_name=model, temperature=temperature)
|
62 |
+
case 'ollama':
|
63 |
+
if model is None:
|
64 |
+
model = 'llama3.1'
|
65 |
+
chat_llm = ChatOllama(model=model, temperature=temperature)
|
66 |
+
case 'openai':
|
67 |
+
if model is None:
|
68 |
+
model = "gpt-4o-mini"
|
69 |
+
chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
|
70 |
+
case 'perplexity':
|
71 |
+
if model is None:
|
72 |
+
model = 'llama-3.1-sonar-small-128k-online'
|
73 |
+
chat_llm = ChatPerplexity(model=model, temperature=temperature)
|
74 |
+
case 'together':
|
75 |
+
if model is None:
|
76 |
+
model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
|
77 |
+
chat_llm = ChatTogether(model=model, temperature=temperature)
|
78 |
+
case _:
|
79 |
+
raise ValueError(f"Unknown LLM provider {provider}")
|
80 |
+
|
81 |
+
return chat_llm
|
82 |
+
|
83 |
+
|
84 |
+
def get_embedding_model(provider_embedding_model):
|
85 |
+
provider, model = (provider_embedding_model.split('/') + [None])[:2]
|
86 |
+
match provider:
|
87 |
+
case 'bedrock':
|
88 |
+
#credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
|
89 |
+
if model is None:
|
90 |
+
model = "cohere.embed-multilingual-v3"
|
91 |
+
embedding_model = BedrockEmbeddings(
|
92 |
+
model_id=model,
|
93 |
+
#credentials_profile_name=credentials_profile_name
|
94 |
+
)
|
95 |
+
case 'cohere':
|
96 |
+
if model is None:
|
97 |
+
model = "embed-english-light-v3.0"
|
98 |
+
embedding_model = CohereEmbeddings(model=model)
|
99 |
+
case 'fireworks':
|
100 |
+
if model is None:
|
101 |
+
model = 'nomic-ai/nomic-embed-text-v1.5'
|
102 |
+
embedding_model = FireworksEmbeddings(model=model)
|
103 |
+
case 'ollama':
|
104 |
+
if model is None:
|
105 |
+
model = 'nomic-embed-text:latest'
|
106 |
+
embedding_model = OllamaEmbeddings(model=model)
|
107 |
+
case 'openai':
|
108 |
+
if model is None:
|
109 |
+
model = "text-embedding-3-small"
|
110 |
+
embedding_model = OpenAIEmbeddings(model=model)
|
111 |
+
case 'googlegenerativeai':
|
112 |
+
if model is None:
|
113 |
+
model = "models/embedding-001"
|
114 |
+
embedding_model = GoogleGenerativeAIEmbeddings(model=model)
|
115 |
+
case 'groq':
|
116 |
+
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
|
117 |
+
case 'perplexity':
|
118 |
+
raise ValueError(f"Cannot use Perplexity for embedding model")
|
119 |
+
case 'together':
|
120 |
+
if model is None:
|
121 |
+
model = 'BAAI/bge-base-en-v1.5'
|
122 |
+
embedding_model = TogetherEmbeddings(model=model)
|
123 |
+
case _:
|
124 |
+
raise ValueError(f"Unknown LLM provider {provider}")
|
125 |
+
|
126 |
+
return embedding_model
|
127 |
+
|
128 |
+
|
129 |
+
import unittest
|
130 |
+
from unittest.mock import patch
|
131 |
+
from models import get_embedding_model # Make sure this import is correct
|
132 |
+
|
133 |
+
class TestGetEmbeddingModel(unittest.TestCase):
|
134 |
+
|
135 |
+
@patch('models.BedrockEmbeddings')
|
136 |
+
def test_bedrock_embedding(self, mock_bedrock):
|
137 |
+
result = get_embedding_model('bedrock')
|
138 |
+
mock_bedrock.assert_called_once_with(model_id='cohere.embed-multilingual-v3')
|
139 |
+
self.assertEqual(result, mock_bedrock.return_value)
|
140 |
+
|
141 |
+
@patch('models.CohereEmbeddings')
|
142 |
+
def test_cohere_embedding(self, mock_cohere):
|
143 |
+
result = get_embedding_model('cohere')
|
144 |
+
mock_cohere.assert_called_once_with(model='embed-english-light-v3.0')
|
145 |
+
self.assertEqual(result, mock_cohere.return_value)
|
146 |
+
|
147 |
+
@patch('models.FireworksEmbeddings')
|
148 |
+
def test_fireworks_embedding(self, mock_fireworks):
|
149 |
+
result = get_embedding_model('fireworks')
|
150 |
+
mock_fireworks.assert_called_once_with(model='nomic-ai/nomic-embed-text-v1.5')
|
151 |
+
self.assertEqual(result, mock_fireworks.return_value)
|
152 |
+
|
153 |
+
@patch('models.OllamaEmbeddings')
|
154 |
+
def test_ollama_embedding(self, mock_ollama):
|
155 |
+
result = get_embedding_model('ollama')
|
156 |
+
mock_ollama.assert_called_once_with(model='nomic-embed-text:latest')
|
157 |
+
self.assertEqual(result, mock_ollama.return_value)
|
158 |
+
|
159 |
+
@patch('models.OpenAIEmbeddings')
|
160 |
+
def test_openai_embedding(self, mock_openai):
|
161 |
+
result = get_embedding_model('openai')
|
162 |
+
mock_openai.assert_called_once_with(model='text-embedding-3-small')
|
163 |
+
self.assertEqual(result, mock_openai.return_value)
|
164 |
+
|
165 |
+
@patch('models.GoogleGenerativeAIEmbeddings')
|
166 |
+
def test_google_embedding(self, mock_google):
|
167 |
+
result = get_embedding_model('googlegenerativeai')
|
168 |
+
mock_google.assert_called_once_with(model='models/embedding-001')
|
169 |
+
self.assertEqual(result, mock_google.return_value)
|
170 |
+
|
171 |
+
@patch('models.TogetherEmbeddings')
|
172 |
+
def test_together_embedding(self, mock_together):
|
173 |
+
result = get_embedding_model('together')
|
174 |
+
mock_together.assert_called_once_with(model='BAAI/bge-base-en-v1.5')
|
175 |
+
self.assertEqual(result, mock_together.return_value)
|
176 |
+
|
177 |
+
def test_invalid_provider(self):
|
178 |
+
with self.assertRaises(ValueError):
|
179 |
+
get_embedding_model('invalid_provider')
|
180 |
+
|
181 |
+
def test_groq_provider(self):
|
182 |
+
with self.assertRaises(ValueError):
|
183 |
+
get_embedding_model('groq')
|
184 |
+
|
185 |
+
def test_perplexity_provider(self):
|
186 |
+
with self.assertRaises(ValueError):
|
187 |
+
get_embedding_model('perplexity')
|
188 |
+
|
189 |
+
|
190 |
+
import unittest
|
191 |
+
from unittest.mock import patch
|
192 |
+
from models import get_model # Make sure this import is correct
|
193 |
+
|
194 |
+
class TestGetModel(unittest.TestCase):
|
195 |
+
|
196 |
+
@patch('models.ChatBedrockConverse')
|
197 |
+
def test_bedrock_model(self, mock_bedrock):
|
198 |
+
result = get_model('bedrock')
|
199 |
+
mock_bedrock.assert_called_once_with(
|
200 |
+
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
201 |
+
temperature=0.0
|
202 |
+
)
|
203 |
+
self.assertEqual(result, mock_bedrock.return_value)
|
204 |
+
|
205 |
+
@patch('models.ChatCohere')
|
206 |
+
def test_cohere_model(self, mock_cohere):
|
207 |
+
result = get_model('cohere')
|
208 |
+
mock_cohere.assert_called_once_with(model='command-r-plus', temperature=0.0)
|
209 |
+
self.assertEqual(result, mock_cohere.return_value)
|
210 |
+
|
211 |
+
@patch('models.ChatFireworks')
|
212 |
+
def test_fireworks_model(self, mock_fireworks):
|
213 |
+
result = get_model('fireworks')
|
214 |
+
mock_fireworks.assert_called_once_with(
|
215 |
+
model_name='accounts/fireworks/models/llama-v3p1-8b-instruct',
|
216 |
+
temperature=0.0,
|
217 |
+
max_tokens=120000
|
218 |
+
)
|
219 |
+
self.assertEqual(result, mock_fireworks.return_value)
|
220 |
+
|
221 |
+
@patch('models.ChatGoogleGenerativeAI')
|
222 |
+
def test_google_model(self, mock_google):
|
223 |
+
result = get_model('googlegenerativeai')
|
224 |
+
mock_google.assert_called_once_with(
|
225 |
+
model="gemini-1.5-pro",
|
226 |
+
temperature=0.0,
|
227 |
+
max_tokens=None,
|
228 |
+
timeout=None,
|
229 |
+
max_retries=2
|
230 |
+
)
|
231 |
+
self.assertEqual(result, mock_google.return_value)
|
232 |
+
|
233 |
+
@patch('models.ChatGroq')
|
234 |
+
def test_groq_model(self, mock_groq):
|
235 |
+
result = get_model('groq')
|
236 |
+
mock_groq.assert_called_once_with(model_name='llama-3.1-8b-instant', temperature=0.0)
|
237 |
+
self.assertEqual(result, mock_groq.return_value)
|
238 |
+
|
239 |
+
@patch('models.ChatOllama')
|
240 |
+
def test_ollama_model(self, mock_ollama):
|
241 |
+
result = get_model('ollama')
|
242 |
+
mock_ollama.assert_called_once_with(model='llama3.1', temperature=0.0)
|
243 |
+
self.assertEqual(result, mock_ollama.return_value)
|
244 |
+
|
245 |
+
@patch('models.ChatOpenAI')
|
246 |
+
def test_openai_model(self, mock_openai):
|
247 |
+
result = get_model('openai')
|
248 |
+
mock_openai.assert_called_once_with(model_name='gpt-4o-mini', temperature=0.0)
|
249 |
+
self.assertEqual(result, mock_openai.return_value)
|
250 |
+
|
251 |
+
@patch('models.ChatPerplexity')
|
252 |
+
def test_perplexity_model(self, mock_perplexity):
|
253 |
+
result = get_model('perplexity')
|
254 |
+
mock_perplexity.assert_called_once_with(model='llama-3.1-sonar-small-128k-online', temperature=0.0)
|
255 |
+
self.assertEqual(result, mock_perplexity.return_value)
|
256 |
+
|
257 |
+
@patch('models.ChatTogether')
|
258 |
+
def test_together_model(self, mock_together):
|
259 |
+
result = get_model('together')
|
260 |
+
mock_together.assert_called_once_with(model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', temperature=0.0)
|
261 |
+
self.assertEqual(result, mock_together.return_value)
|
262 |
+
|
263 |
+
def test_invalid_provider(self):
|
264 |
+
with self.assertRaises(ValueError):
|
265 |
+
get_model('invalid_provider')
|
266 |
+
|
267 |
+
def test_custom_temperature(self):
|
268 |
+
with patch('models.ChatOpenAI') as mock_openai:
|
269 |
+
result = get_model('openai', temperature=0.5)
|
270 |
+
mock_openai.assert_called_once_with(model_name='gpt-4o-mini', temperature=0.5)
|
271 |
+
self.assertEqual(result, mock_openai.return_value)
|
272 |
+
|
273 |
+
def test_custom_model(self):
|
274 |
+
with patch('models.ChatOpenAI') as mock_openai:
|
275 |
+
result = get_model('openai/gpt-4')
|
276 |
+
mock_openai.assert_called_once_with(model_name='gpt-4', temperature=0.0)
|
277 |
+
self.assertEqual(result, mock_openai.return_value)
|
278 |
+
|
279 |
+
if __name__ == '__main__':
|
280 |
+
unittest.main()
|
requirements.txt
CHANGED
@@ -18,6 +18,8 @@ langchain_experimental
|
|
18 |
langchain_openai
|
19 |
langchain-ollama
|
20 |
langchain_groq
|
|
|
|
|
21 |
langsmith
|
22 |
schema
|
23 |
streamlit
|
|
|
18 |
langchain_openai
|
19 |
langchain-ollama
|
20 |
langchain_groq
|
21 |
+
langchain-google-genai
|
22 |
+
langchain-together
|
23 |
langsmith
|
24 |
schema
|
25 |
streamlit
|
search_agent.py
CHANGED
@@ -5,10 +5,12 @@ Usage:
|
|
5 |
[--domain=domain]
|
6 |
[--provider=provider]
|
7 |
[--model=model]
|
|
|
8 |
[--temperature=temp]
|
9 |
[--copywrite]
|
10 |
[--max_pages=num]
|
11 |
[--max_extracts=num]
|
|
|
12 |
[--output=text]
|
13 |
SEARCH_QUERY
|
14 |
search_agent.py --version
|
@@ -19,10 +21,11 @@ Options:
|
|
19 |
-c --copywrite First produce a draft, review it and rewrite for a final text
|
20 |
-d domain --domain=domain Limit search to a specific domain
|
21 |
-t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
|
22 |
-
-
|
23 |
-
-
|
24 |
-n num --max_pages=num Max number of pages to retrieve [default: 10]
|
25 |
-e num --max_extracts=num Max number of page extract to consider [default: 5]
|
|
|
26 |
-o text --output=text Output format (choices: text, markdown) [default: markdown]
|
27 |
|
28 |
"""
|
@@ -35,7 +38,7 @@ import dotenv
|
|
35 |
|
36 |
from langchain.callbacks import LangChainTracer
|
37 |
|
38 |
-
from langsmith import Client
|
39 |
|
40 |
from rich.console import Console
|
41 |
from rich.markdown import Markdown
|
@@ -43,6 +46,7 @@ from rich.markdown import Markdown
|
|
43 |
import web_rag as wr
|
44 |
import web_crawler as wc
|
45 |
import copywriter as cw
|
|
|
46 |
|
47 |
console = Console()
|
48 |
dotenv.load_dotenv()
|
@@ -70,34 +74,24 @@ if os.getenv("LANGCHAIN_API_KEY"):
|
|
70 |
callbacks.append(
|
71 |
LangChainTracer(client=Client())
|
72 |
)
|
73 |
-
|
74 |
-
|
75 |
-
arguments = docopt(__doc__, version='Search Agent 0.1')
|
76 |
-
|
77 |
-
#schema = Schema({
|
78 |
-
# '--max_pages': Use(int, error='--max_pages must be an integer'),
|
79 |
-
# '--temperature': Use(float, error='--temperature must be an float'),
|
80 |
-
#})
|
81 |
-
|
82 |
-
#try:
|
83 |
-
# arguments = schema.validate(arguments)
|
84 |
-
#except SchemaError as e:
|
85 |
-
# exit(e)
|
86 |
-
|
87 |
copywrite_mode = arguments["--copywrite"]
|
88 |
-
provider = arguments["--provider"]
|
89 |
model = arguments["--model"]
|
|
|
90 |
temperature = float(arguments["--temperature"])
|
91 |
domain=arguments["--domain"]
|
92 |
-
max_pages=arguments["--max_pages"]
|
93 |
max_extract=int(arguments["--max_extracts"])
|
94 |
output=arguments["--output"]
|
|
|
95 |
query = arguments["SEARCH_QUERY"]
|
96 |
|
97 |
-
chat
|
|
|
98 |
|
99 |
with console.status(f"[bold green]Optimizing query for search: {query}"):
|
100 |
-
optimize_search_query = wr.optimize_search_query(chat, query
|
101 |
if len(optimize_search_query) < 3:
|
102 |
optimize_search_query = query
|
103 |
console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
|
@@ -111,16 +105,16 @@ if __name__ == '__main__':
|
|
111 |
with console.status(
|
112 |
f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
|
113 |
):
|
114 |
-
contents = wc.get_links_contents(sources, get_selenium_driver)
|
115 |
console.log(f"Managed to extract content from {len(contents)} sources")
|
116 |
|
117 |
with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
|
118 |
vector_store = wc.vectorize(contents, embedding_model)
|
119 |
|
120 |
with console.status("[bold green]Writing content", spinner='dots8Bit'):
|
121 |
-
draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract
|
122 |
|
123 |
-
console.rule(f"[bold green]Response
|
124 |
if output == "text":
|
125 |
console.print(draft)
|
126 |
else:
|
@@ -129,7 +123,7 @@ if __name__ == '__main__':
|
|
129 |
|
130 |
if(copywrite_mode):
|
131 |
with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
|
132 |
-
comments = cw.generate_comments(chat, query, draft
|
133 |
|
134 |
console.rule("[bold green]Response from reviewer")
|
135 |
if output == "text":
|
@@ -139,7 +133,7 @@ if __name__ == '__main__':
|
|
139 |
console.rule("[bold green]")
|
140 |
|
141 |
with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
|
142 |
-
final_text = cw.generate_final_text(chat, query, draft, comments
|
143 |
|
144 |
console.rule("[bold green]Final text")
|
145 |
if output == "text":
|
@@ -147,3 +141,8 @@ if __name__ == '__main__':
|
|
147 |
else:
|
148 |
console.print(Markdown(final_text))
|
149 |
console.rule("[bold green]")
|
|
|
|
|
|
|
|
|
|
|
|
5 |
[--domain=domain]
|
6 |
[--provider=provider]
|
7 |
[--model=model]
|
8 |
+
[--embedding_model=model]
|
9 |
[--temperature=temp]
|
10 |
[--copywrite]
|
11 |
[--max_pages=num]
|
12 |
[--max_extracts=num]
|
13 |
+
[--use_selenium]
|
14 |
[--output=text]
|
15 |
SEARCH_QUERY
|
16 |
search_agent.py --version
|
|
|
21 |
-c --copywrite First produce a draft, review it and rewrite for a final text
|
22 |
-d domain --domain=domain Limit search to a specific domain
|
23 |
-t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
|
24 |
+
-m model --model=model Use a specific model [default: openai/gpt-4o-mini]
|
25 |
+
-e model --embedding_model=model Use a specific embedding model [default: openai/text-embedding-3-small]
|
26 |
-n num --max_pages=num Max number of pages to retrieve [default: 10]
|
27 |
-e num --max_extracts=num Max number of page extract to consider [default: 5]
|
28 |
+
-s --use_selenium Use selenium to fetch content from the web [default: False]
|
29 |
-o text --output=text Output format (choices: text, markdown) [default: markdown]
|
30 |
|
31 |
"""
|
|
|
38 |
|
39 |
from langchain.callbacks import LangChainTracer
|
40 |
|
41 |
+
from langsmith import Client, traceable
|
42 |
|
43 |
from rich.console import Console
|
44 |
from rich.markdown import Markdown
|
|
|
46 |
import web_rag as wr
|
47 |
import web_crawler as wc
|
48 |
import copywriter as cw
|
49 |
+
import models as md
|
50 |
|
51 |
console = Console()
|
52 |
dotenv.load_dotenv()
|
|
|
74 |
callbacks.append(
|
75 |
LangChainTracer(client=Client())
|
76 |
)
|
77 |
+
@traceable(run_type="tool", name="search_agent")
|
78 |
+
def main(arguments):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
copywrite_mode = arguments["--copywrite"]
|
|
|
80 |
model = arguments["--model"]
|
81 |
+
embedding_model = arguments["--embedding_model"]
|
82 |
temperature = float(arguments["--temperature"])
|
83 |
domain=arguments["--domain"]
|
84 |
+
max_pages=int(arguments["--max_pages"])
|
85 |
max_extract=int(arguments["--max_extracts"])
|
86 |
output=arguments["--output"]
|
87 |
+
use_selenium=arguments["--use_selenium"]
|
88 |
query = arguments["SEARCH_QUERY"]
|
89 |
|
90 |
+
chat = md.get_model(model, temperature)
|
91 |
+
embedding_model = md.get_embedding_model(embedding_model)
|
92 |
|
93 |
with console.status(f"[bold green]Optimizing query for search: {query}"):
|
94 |
+
optimize_search_query = wr.optimize_search_query(chat, query)
|
95 |
if len(optimize_search_query) < 3:
|
96 |
optimize_search_query = query
|
97 |
console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
|
|
|
105 |
with console.status(
|
106 |
f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
|
107 |
):
|
108 |
+
contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
|
109 |
console.log(f"Managed to extract content from {len(contents)} sources")
|
110 |
|
111 |
with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
|
112 |
vector_store = wc.vectorize(contents, embedding_model)
|
113 |
|
114 |
with console.status("[bold green]Writing content", spinner='dots8Bit'):
|
115 |
+
draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract)
|
116 |
|
117 |
+
console.rule(f"[bold green]Response")
|
118 |
if output == "text":
|
119 |
console.print(draft)
|
120 |
else:
|
|
|
123 |
|
124 |
if(copywrite_mode):
|
125 |
with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
|
126 |
+
comments = cw.generate_comments(chat, query, draft)
|
127 |
|
128 |
console.rule("[bold green]Response from reviewer")
|
129 |
if output == "text":
|
|
|
133 |
console.rule("[bold green]")
|
134 |
|
135 |
with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
|
136 |
+
final_text = cw.generate_final_text(chat, query, draft, comments)
|
137 |
|
138 |
console.rule("[bold green]Final text")
|
139 |
if output == "text":
|
|
|
141 |
else:
|
142 |
console.print(Markdown(final_text))
|
143 |
console.rule("[bold green]")
|
144 |
+
|
145 |
+
if __name__ == '__main__':
|
146 |
+
arguments = docopt(__doc__, version='Search Agent 0.1')
|
147 |
+
main(arguments)
|
148 |
+
|
search_agent_ui.py
CHANGED
@@ -11,7 +11,7 @@ from langsmith.client import Client
|
|
11 |
import web_rag as wr
|
12 |
import web_crawler as wc
|
13 |
import copywriter as cw
|
14 |
-
|
15 |
dotenv.load_dotenv()
|
16 |
|
17 |
ls_tracer = LangChainTracer(
|
@@ -54,26 +54,26 @@ def create_links_markdown(sources_list):
|
|
54 |
st.set_page_config(layout="wide")
|
55 |
st.title("🔍 Simple Search Agent 💬")
|
56 |
|
57 |
-
if "
|
58 |
-
|
59 |
if os.getenv("FIREWORKS_API_KEY"):
|
60 |
-
|
61 |
if os.getenv("COHERE_API_KEY"):
|
62 |
-
|
63 |
if os.getenv("OPENAI_API_KEY"):
|
64 |
-
|
65 |
if os.getenv("GROQ_API_KEY"):
|
66 |
-
|
67 |
if os.getenv("OLLAMA_API_KEY"):
|
68 |
-
|
69 |
if os.getenv("CREDENTIALS_PROFILE_NAME"):
|
70 |
-
|
71 |
-
st.session_state["
|
72 |
|
73 |
with st.sidebar.expander("Options", expanded=False):
|
74 |
-
model_provider = st.selectbox("Model provider 🧠", st.session_state["
|
75 |
temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
|
76 |
-
max_pages = st.slider("Max pages to retrieve 🔍", 1, 20,
|
77 |
top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
|
78 |
reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
|
79 |
|
@@ -108,7 +108,8 @@ if prompt := st.chat_input("Enter you instructions..." ):
|
|
108 |
st.chat_message("user").write(prompt)
|
109 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
110 |
|
111 |
-
chat
|
|
|
112 |
|
113 |
with st.status("Thinking", expanded=True):
|
114 |
st.write("I first need to do some research")
|
@@ -120,7 +121,7 @@ if prompt := st.chat_input("Enter you instructions..." ):
|
|
120 |
links_md.markdown(create_links_markdown(sources))
|
121 |
|
122 |
st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
|
123 |
-
contents = wc.get_links_contents(sources)
|
124 |
|
125 |
st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
|
126 |
vector_store = wc.vectorize(contents, embedding_model=embedding_model)
|
@@ -147,8 +148,15 @@ if prompt := st.chat_input("Enter you instructions..." ):
|
|
147 |
|
148 |
with st.chat_message("assistant"):
|
149 |
st_cb = StreamHandler(st.empty())
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
153 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
154 |
|
|
|
11 |
import web_rag as wr
|
12 |
import web_crawler as wc
|
13 |
import copywriter as cw
|
14 |
+
import models as md
|
15 |
dotenv.load_dotenv()
|
16 |
|
17 |
ls_tracer = LangChainTracer(
|
|
|
54 |
st.set_page_config(layout="wide")
|
55 |
st.title("🔍 Simple Search Agent 💬")
|
56 |
|
57 |
+
if "models" not in st.session_state:
|
58 |
+
models = []
|
59 |
if os.getenv("FIREWORKS_API_KEY"):
|
60 |
+
models.append("fireworks")
|
61 |
if os.getenv("COHERE_API_KEY"):
|
62 |
+
models.append("cohere")
|
63 |
if os.getenv("OPENAI_API_KEY"):
|
64 |
+
models.append("openai")
|
65 |
if os.getenv("GROQ_API_KEY"):
|
66 |
+
models.append("groq")
|
67 |
if os.getenv("OLLAMA_API_KEY"):
|
68 |
+
models.append("ollama")
|
69 |
if os.getenv("CREDENTIALS_PROFILE_NAME"):
|
70 |
+
models.append("bedrock")
|
71 |
+
st.session_state["models"] = models
|
72 |
|
73 |
with st.sidebar.expander("Options", expanded=False):
|
74 |
+
model_provider = st.selectbox("Model provider 🧠", st.session_state["models"])
|
75 |
temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
|
76 |
+
max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to retrive from the internet")
|
77 |
top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
|
78 |
reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
|
79 |
|
|
|
108 |
st.chat_message("user").write(prompt)
|
109 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
110 |
|
111 |
+
chat = md.get_model(model_provider, temperature)
|
112 |
+
embedding_model = md.get_embedding_model(model_provider)
|
113 |
|
114 |
with st.status("Thinking", expanded=True):
|
115 |
st.write("I first need to do some research")
|
|
|
121 |
links_md.markdown(create_links_markdown(sources))
|
122 |
|
123 |
st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
|
124 |
+
contents = wc.get_links_contents(sources, use_selenium=False)
|
125 |
|
126 |
st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
|
127 |
vector_store = wc.vectorize(contents, embedding_model=embedding_model)
|
|
|
148 |
|
149 |
with st.chat_message("assistant"):
|
150 |
st_cb = StreamHandler(st.empty())
|
151 |
+
if hasattr(chat, 'stream'):
|
152 |
+
response = ""
|
153 |
+
for chunk in chat.stream(rag_prompt, config={"callbacks": [st_cb, ls_tracer]}):
|
154 |
+
response += chunk.content
|
155 |
+
else:
|
156 |
+
result = chat.invoke(rag_prompt, config={"callbacks": [st_cb, ls_tracer]})
|
157 |
+
response = result.content
|
158 |
+
|
159 |
+
response = response.strip()
|
160 |
message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
161 |
st.session_state.messages.append({"role": "assistant", "content": response})
|
162 |
|
web_crawler.py
CHANGED
@@ -8,12 +8,14 @@ from trafilatura import extract
|
|
8 |
from selenium.common.exceptions import TimeoutException
|
9 |
from langchain_core.documents.base import Document
|
10 |
from langchain_experimental.text_splitter import SemanticChunker
|
|
|
11 |
from langchain_openai import OpenAIEmbeddings
|
12 |
from langchain_community.vectorstores.faiss import FAISS
|
13 |
-
|
14 |
import requests
|
15 |
import pdfplumber
|
16 |
|
|
|
17 |
def get_sources(query, max_pages=10, domain=None):
|
18 |
search_query = query
|
19 |
if domain:
|
@@ -78,8 +80,7 @@ def fetch_with_timeout(url, timeout=8):
|
|
78 |
|
79 |
def process_source(source):
|
80 |
url = source['link']
|
81 |
-
|
82 |
-
response = fetch_with_timeout(url, 8)
|
83 |
if response:
|
84 |
content_type = response.headers.get('Content-Type')
|
85 |
if content_type:
|
@@ -107,12 +108,13 @@ def process_source(source):
|
|
107 |
return {**source, 'page_content': source['snippet']}
|
108 |
return {**source, 'page_content': None}
|
109 |
|
110 |
-
|
|
|
111 |
with ThreadPoolExecutor() as executor:
|
112 |
results = list(executor.map(process_source, sources))
|
113 |
|
114 |
-
if get_driver_func is None:
|
115 |
-
return [result for result in results if result is not None]
|
116 |
|
117 |
for result in results:
|
118 |
if result['page_content'] is None:
|
@@ -125,19 +127,49 @@ def get_links_contents(sources, get_driver_func=None):
|
|
125 |
result['page_content'] = main_content
|
126 |
return results
|
127 |
|
|
|
128 |
def vectorize(contents, embedding_model):
|
129 |
documents = []
|
|
|
130 |
for content in contents:
|
131 |
try:
|
132 |
page_content = content['page_content']
|
133 |
-
if page_content:
|
134 |
metadata = {'title': content['title'], 'source': content['link']}
|
135 |
doc = Document(page_content=content['page_content'], metadata=metadata)
|
136 |
documents.append(doc)
|
|
|
137 |
except Exception as e:
|
138 |
print(f"[gray]Error processing content for {content['link']}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
semantic_chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from selenium.common.exceptions import TimeoutException
|
9 |
from langchain_core.documents.base import Document
|
10 |
from langchain_experimental.text_splitter import SemanticChunker
|
11 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
12 |
from langchain_openai import OpenAIEmbeddings
|
13 |
from langchain_community.vectorstores.faiss import FAISS
|
14 |
+
from langsmith import traceable
|
15 |
import requests
|
16 |
import pdfplumber
|
17 |
|
18 |
+
@traceable(run_type="tool", name="get_sources")
|
19 |
def get_sources(query, max_pages=10, domain=None):
|
20 |
search_query = query
|
21 |
if domain:
|
|
|
80 |
|
81 |
def process_source(source):
|
82 |
url = source['link']
|
83 |
+
response = fetch_with_timeout(url, 2)
|
|
|
84 |
if response:
|
85 |
content_type = response.headers.get('Content-Type')
|
86 |
if content_type:
|
|
|
108 |
return {**source, 'page_content': source['snippet']}
|
109 |
return {**source, 'page_content': None}
|
110 |
|
111 |
+
@traceable(run_type="tool", name="get_links_contents")
|
112 |
+
def get_links_contents(sources, get_driver_func=None, use_selenium=False):
|
113 |
with ThreadPoolExecutor() as executor:
|
114 |
results = list(executor.map(process_source, sources))
|
115 |
|
116 |
+
if get_driver_func is None or not use_selenium:
|
117 |
+
return [result for result in results if result is not None and result['page_content']]
|
118 |
|
119 |
for result in results:
|
120 |
if result['page_content'] is None:
|
|
|
127 |
result['page_content'] = main_content
|
128 |
return results
|
129 |
|
130 |
+
@traceable(run_type="embedding")
|
131 |
def vectorize(contents, embedding_model):
|
132 |
documents = []
|
133 |
+
total_content_length = 0
|
134 |
for content in contents:
|
135 |
try:
|
136 |
page_content = content['page_content']
|
137 |
+
if page_content:
|
138 |
metadata = {'title': content['title'], 'source': content['link']}
|
139 |
doc = Document(page_content=content['page_content'], metadata=metadata)
|
140 |
documents.append(doc)
|
141 |
+
total_content_length += len(page_content)
|
142 |
except Exception as e:
|
143 |
print(f"[gray]Error processing content for {content['link']}: {e}")
|
144 |
+
|
145 |
+
# Define a threshold for when to use pre-splitting (e.g., 1 million characters)
|
146 |
+
pre_split_threshold = 1_000_000
|
147 |
+
|
148 |
+
if total_content_length > pre_split_threshold:
|
149 |
+
# Use pre-splitting for large datasets
|
150 |
+
pre_splitter = RecursiveCharacterTextSplitter(
|
151 |
+
chunk_size=2000,
|
152 |
+
chunk_overlap=200,
|
153 |
+
length_function=len,
|
154 |
+
)
|
155 |
+
documents = pre_splitter.split_documents(documents)
|
156 |
+
|
157 |
semantic_chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
|
158 |
+
|
159 |
+
vector_store = None
|
160 |
+
batch_size = 200 # Adjust this value if needed
|
161 |
+
|
162 |
+
for i in range(0, len(documents), batch_size):
|
163 |
+
batch = documents[i:i+batch_size]
|
164 |
+
|
165 |
+
# Split each document in the batch using SemanticChunker
|
166 |
+
chunked_docs = []
|
167 |
+
for doc in batch:
|
168 |
+
chunked_docs.extend(semantic_chunker.split_documents([doc]))
|
169 |
+
|
170 |
+
if vector_store is None:
|
171 |
+
vector_store = FAISS.from_documents(chunked_docs, embedding_model)
|
172 |
+
else:
|
173 |
+
vector_store.add_documents(chunked_docs)
|
174 |
+
|
175 |
+
return vector_store
|
web_rag.py
CHANGED
@@ -36,53 +36,7 @@ from langchain_groq.chat_models import ChatGroq
|
|
36 |
from langchain_openai import ChatOpenAI
|
37 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
38 |
from langchain_ollama.chat_models import ChatOllama
|
39 |
-
|
40 |
-
def get_models(provider, model=None, temperature=0.0):
|
41 |
-
match provider:
|
42 |
-
case 'bedrock':
|
43 |
-
credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
|
44 |
-
if model is None:
|
45 |
-
model = "anthropic.claude-3-sonnet-20240229-v1:0"
|
46 |
-
chat_llm = ChatBedrockConverse(
|
47 |
-
credentials_profile_name=credentials_profile_name,
|
48 |
-
model=model,
|
49 |
-
temperature=temperature,
|
50 |
-
)
|
51 |
-
embedding_model = BedrockEmbeddings(
|
52 |
-
model_id='cohere.embed-multilingual-v3',
|
53 |
-
credentials_profile_name=credentials_profile_name
|
54 |
-
)
|
55 |
-
embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
|
56 |
-
case 'openai':
|
57 |
-
if model is None:
|
58 |
-
model = "gpt-4o-mini"
|
59 |
-
chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
|
60 |
-
embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
|
61 |
-
case 'groq':
|
62 |
-
if model is None:
|
63 |
-
model = 'mixtral-8x7b-32768'
|
64 |
-
chat_llm = ChatGroq(model_name=model, temperature=temperature)
|
65 |
-
embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
|
66 |
-
case 'ollama':
|
67 |
-
if model is None:
|
68 |
-
model = 'llama3.1'
|
69 |
-
chat_llm = ChatOllama(model=model, temperature=temperature)
|
70 |
-
embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
|
71 |
-
case 'cohere':
|
72 |
-
if model is None:
|
73 |
-
model = 'command-r-plus'
|
74 |
-
chat_llm = ChatCohere(model=model, temperature=temperature)
|
75 |
-
#embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
|
76 |
-
embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
|
77 |
-
case 'fireworks':
|
78 |
-
if model is None:
|
79 |
-
model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
|
80 |
-
chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
|
81 |
-
embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
|
82 |
-
case _:
|
83 |
-
raise ValueError(f"Unknown LLM provider {provider}")
|
84 |
-
|
85 |
-
return chat_llm, embedding_model
|
86 |
|
87 |
|
88 |
def get_optimized_search_messages(query):
|
@@ -97,12 +51,13 @@ def get_optimized_search_messages(query):
|
|
97 |
"""
|
98 |
system_message = SystemMessage(
|
99 |
content="""
|
100 |
-
|
101 |
-
|
|
|
102 |
To optimize the prompt:
|
103 |
- Identify the key information being requested
|
|
|
104 |
- Arrange the keywords into a concise search string
|
105 |
-
- Keep it short, around 1 to 5 words total
|
106 |
- Put the most important keywords first
|
107 |
|
108 |
Some tips and things to be sure to remove:
|
@@ -111,7 +66,7 @@ def get_optimized_search_messages(query):
|
|
111 |
- Remove lenght instruction (example: essay, article, letter, blog, post, blogpost, etc)
|
112 |
- Remove style instructions (exmaple: "in the style of", engaging, short, long)
|
113 |
- Remove lenght instruction (example: essay, article, letter, etc)
|
114 |
-
|
115 |
You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
|
116 |
|
117 |
Example:
|
@@ -119,19 +74,16 @@ def get_optimized_search_messages(query):
|
|
119 |
chocolate chip cookies recipe from scratch**
|
120 |
Example:
|
121 |
Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
|
122 |
-
Marie Curie timeline**
|
123 |
Example:
|
124 |
Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
|
125 |
geopolitics nato russia**
|
126 |
Example:
|
127 |
Question: Write an engaging LinkedIn post about Andrew Ng
|
128 |
-
Andrew Ng**
|
129 |
Example:
|
130 |
Question: Write a short article about the solar system in the style of Carl Sagan
|
131 |
solar system**
|
132 |
-
Example:
|
133 |
-
Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
|
134 |
-
Kubernetes decision**
|
135 |
Example:
|
136 |
Question: Biography of Napoleon. Include a table with the major events.
|
137 |
napoleon biography events**
|
@@ -155,12 +107,73 @@ def get_optimized_search_messages(query):
|
|
155 |
return [system_message, human_message]
|
156 |
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
def optimize_search_query(chat_llm, query, callbacks=[]):
|
159 |
messages = get_optimized_search_messages(query)
|
160 |
-
response = chat_llm.invoke(messages
|
161 |
-
optimized_search_query = response.content
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
def get_rag_prompt_template():
|
166 |
"""
|
@@ -185,8 +198,9 @@ def get_rag_prompt_template():
|
|
185 |
- Format your answer in Markdown, using heading levels 2-3 as needed
|
186 |
- Include a "References" section at the end with the full citations and link for each source you used
|
187 |
|
188 |
-
If
|
189 |
-
If
|
|
|
190 |
"""
|
191 |
)
|
192 |
)
|
@@ -245,7 +259,7 @@ def get_context_size(chat_llm):
|
|
245 |
if isinstance(chat_llm, ChatOllama):
|
246 |
return 120000
|
247 |
if isinstance(chat_llm, ChatCohere):
|
248 |
-
return
|
249 |
if isinstance(chat_llm, ChatBedrockConverse):
|
250 |
if chat_llm.model_id.startswith("meta.llama3-1"):
|
251 |
return 128000
|
@@ -259,7 +273,7 @@ def get_context_size(chat_llm):
|
|
259 |
return 32000
|
260 |
return 4096
|
261 |
|
262 |
-
|
263 |
def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
|
264 |
done = False
|
265 |
while not done:
|
@@ -275,6 +289,7 @@ def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10,
|
|
275 |
|
276 |
return prompt
|
277 |
|
|
|
278 |
def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
|
279 |
prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
|
280 |
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
|
|
|
36 |
from langchain_openai import ChatOpenAI
|
37 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
38 |
from langchain_ollama.chat_models import ChatOllama
|
39 |
+
from langsmith import traceable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def get_optimized_search_messages(query):
|
|
|
51 |
"""
|
52 |
system_message = SystemMessage(
|
53 |
content="""
|
54 |
+
You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
|
55 |
+
The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
|
56 |
+
|
57 |
To optimize the prompt:
|
58 |
- Identify the key information being requested
|
59 |
+
- Consider any implicit information or context that might be useful for the search.
|
60 |
- Arrange the keywords into a concise search string
|
|
|
61 |
- Put the most important keywords first
|
62 |
|
63 |
Some tips and things to be sure to remove:
|
|
|
66 |
- Remove lenght instruction (example: essay, article, letter, blog, post, blogpost, etc)
|
67 |
- Remove style instructions (exmaple: "in the style of", engaging, short, long)
|
68 |
- Remove lenght instruction (example: essay, article, letter, etc)
|
69 |
+
|
70 |
You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
|
71 |
|
72 |
Example:
|
|
|
74 |
chocolate chip cookies recipe from scratch**
|
75 |
Example:
|
76 |
Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
|
77 |
+
"Marie Curie" timeline**
|
78 |
Example:
|
79 |
Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
|
80 |
geopolitics nato russia**
|
81 |
Example:
|
82 |
Question: Write an engaging LinkedIn post about Andrew Ng
|
83 |
+
"Andrew Ng"**
|
84 |
Example:
|
85 |
Question: Write a short article about the solar system in the style of Carl Sagan
|
86 |
solar system**
|
|
|
|
|
|
|
87 |
Example:
|
88 |
Question: Biography of Napoleon. Include a table with the major events.
|
89 |
napoleon biography events**
|
|
|
107 |
return [system_message, human_message]
|
108 |
|
109 |
|
110 |
+
|
111 |
+
def get_optimized_search_messages2(query):
|
112 |
+
"""
|
113 |
+
Generate optimized search messages for a given query.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
query (str): The user's query.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
list: A list containing the system message and human message for optimized search.
|
120 |
+
"""
|
121 |
+
system_message = SystemMessage(
|
122 |
+
content="""
|
123 |
+
You are a prompt optimizer for web search. Your task is to take a given chat prompt or question and transform it into an optimized search string that will yield the most relevant and useful information from a search engine like Google.
|
124 |
+
|
125 |
+
The goal is to create a search query that will help users find the most accurate and pertinent information related to their original prompt or question. An effective search string should be concise, use relevant keywords, and leverage search engine syntax for better results.
|
126 |
+
|
127 |
+
Here are some key principles for creating effective search queries:
|
128 |
+
1. Use specific and relevant keywords
|
129 |
+
2. Remove unnecessary words (articles, prepositions, etc.)
|
130 |
+
3. Utilize quotation marks for exact phrases
|
131 |
+
4. Employ Boolean operators (AND, OR, NOT) when appropriate
|
132 |
+
5. Include synonyms or related terms to broaden the search
|
133 |
+
|
134 |
+
I will provide you with a chat prompt or question. Your task is to optimize this into an effective search string.
|
135 |
+
|
136 |
+
Process the input as follows:
|
137 |
+
1. Analyze the Question to identify the main topic and key concepts.
|
138 |
+
2. Extract the most relevant keywords and phrases.
|
139 |
+
3. Consider any implicit information or context that might be useful for the search.
|
140 |
+
|
141 |
+
Then, optimize the search string by:
|
142 |
+
1. Removing filler words and unnecessary language
|
143 |
+
2. Rearranging keywords in a logical order
|
144 |
+
3. Adding quotation marks around exact phrases if applicable
|
145 |
+
4. Including relevant synonyms or related terms (in parentheses) to broaden the search
|
146 |
+
5. Using Boolean operators if needed to refine the search
|
147 |
+
|
148 |
+
You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the optimized search query
|
149 |
+
"""
|
150 |
+
)
|
151 |
+
human_message = HumanMessage(
|
152 |
+
content=f"""
|
153 |
+
Question: {query}
|
154 |
+
|
155 |
+
"""
|
156 |
+
)
|
157 |
+
return [system_message, human_message]
|
158 |
+
|
159 |
+
|
160 |
+
@traceable(run_type="llm", name="optimize_search_query")
|
161 |
def optimize_search_query(chat_llm, query, callbacks=[]):
|
162 |
messages = get_optimized_search_messages(query)
|
163 |
+
response = chat_llm.invoke(messages)
|
164 |
+
optimized_search_query = response.content.strip()
|
165 |
+
|
166 |
+
# Split by '**' and take the first part, then strip whitespace
|
167 |
+
optimized_search_query = optimized_search_query.split("**", 1)[0].strip()
|
168 |
+
|
169 |
+
# Remove surrounding quotes if present
|
170 |
+
optimized_search_query = optimized_search_query.strip('"')
|
171 |
+
|
172 |
+
# If the result is empty, fall back to the original query
|
173 |
+
if not optimized_search_query:
|
174 |
+
optimized_search_query = query
|
175 |
+
|
176 |
+
return optimized_search_query
|
177 |
|
178 |
def get_rag_prompt_template():
|
179 |
"""
|
|
|
198 |
- Format your answer in Markdown, using heading levels 2-3 as needed
|
199 |
- Include a "References" section at the end with the full citations and link for each source you used
|
200 |
|
201 |
+
If the provided context is not relevant to the question, say it and answer with your internal knowledge.
|
202 |
+
If you cannot answer the question using either the extracts or your internal knowledge, state that you don't have enough information to provide an accurate answer.
|
203 |
+
If the information in the provided context is in contradiction with your internal knowledge, answer but warn the user about the contradiction.
|
204 |
"""
|
205 |
)
|
206 |
)
|
|
|
259 |
if isinstance(chat_llm, ChatOllama):
|
260 |
return 120000
|
261 |
if isinstance(chat_llm, ChatCohere):
|
262 |
+
return 120000
|
263 |
if isinstance(chat_llm, ChatBedrockConverse):
|
264 |
if chat_llm.model_id.startswith("meta.llama3-1"):
|
265 |
return 128000
|
|
|
273 |
return 32000
|
274 |
return 4096
|
275 |
|
276 |
+
@traceable(run_type="retriever")
|
277 |
def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
|
278 |
done = False
|
279 |
while not done:
|
|
|
289 |
|
290 |
return prompt
|
291 |
|
292 |
+
@traceable(run_type="llm", name="query_rag")
|
293 |
def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
|
294 |
prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
|
295 |
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
|