Eddie Pick commited on
Commit
6f80de5
·
unverified ·
1 Parent(s): d803be1

Improvements

Browse files
Files changed (6) hide show
  1. README.md +34 -47
  2. models.py +9 -18
  3. search_agent.py +15 -7
  4. search_agent_ui.py +27 -8
  5. web_crawler.py +18 -27
  6. web_rag.py +17 -1
README.md CHANGED
@@ -10,26 +10,25 @@ pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- ⚠️ **This project is a demonstration / proof-of-concept and is not intended for use in production environments. It is provided as-is, without warranty or guarantee of any kind. The code and any accompanying materials are for educational, testing, or evaluation purposes only.**⚠️
14
 
15
  # Simple Search Agent
16
 
17
- This Python project provides a search agent that can perform web searches, optimize search queries, fetch and process web content, and generate responses using a language model and the retrieved information.
18
- Does a bit what [Perplexity AI](https://www.perplexity.ai/) does.
19
 
20
  The Streamlit GUI hosted on 🤗 Spaces is [available to test](https://huggingface.co/spaces/CyranoB/search_agent)
21
 
22
- This Python script and Streamli GUI are a basic search agent that utilizes the LangChain library to perform optimized web searches, retrieve relevant content, and generate informative answers to user queries. The script supports multiple language models and providers, including OpenAI, Anthropic, and Groq.
23
 
24
  The main functionality of the script can be summarized as follows:
25
 
26
  1. **Query Optimization**: The user's input query is optimized for web search by identifying the key information requested and transforming it into a concise search string using the language model's capabilities.
27
  2. **Web Search**: The optimized search query is used to fetch search results from the Brave Search API. The script allows limiting the search to a specific domain and setting the maximum number of pages to retrieve.
28
  3. **Content Extraction**: The script fetches the content of the retrieved search results, handling both HTML and PDF documents. It extracts the main text content from web pages and text from PDF files.
29
- 4. **Vectorization**: The extracted content is split into smaller text chunks and vectorized using OpenAI's text embeddings. The vectorized data is stored in a FAISS vector store for efficient retrieval.
30
- 5. **Query Answering**: The user's original query is answered by retrieving the most relevant text chunks from the vector store using a Multi-Query Retriever. The language model generates an informative answer by synthesizing the retrieved information, citing the sources used, and formatting the response in Markdown.
31
 
32
- The script supports various options for customization, such as specifying the language model provider (OpenAI, Anthropic, Groq, or OllaMa), temperature for language model generation, and output format (text or Markdown).
33
 
34
  Additionally, the script integrates with the LangChain Tracing V2 feature, allowing users to monitor and analyze the execution of their LangChain applications using the LangChain Studio.
35
 
@@ -48,58 +47,46 @@ To run the script, users need to provide their API keys for the desired language
48
 
49
  1. Clone this repo
50
  2. Install the required dependencies:
51
- ```
 
52
  pip install -r requirements.txt
53
  ```
 
54
  3. Set up API keys:
55
- - You will need API keys for the web search API and LLM API.
 
56
  - Add your API keys to the `.env` file. Use `dotenv.sample` to create this file.
57
 
58
  ## Usage
59
 
 
 
 
 
60
  ```
61
- python search_agent.py --query "your search query" --provider "provider_name" --model "model_name" --temperature 0.0
62
- ```
63
 
64
- Replace `"your search query"` with your desired search query, `"provider_name"` with the language model provider (e.g., `bedrock`, `openai`, `groq`, `ollama`), `"model_name"` with the specific model name (optional), and `temperature` with the desired temperature value for the language model (optional).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- Example:
 
67
  ```
68
- ➜ python ./search_agent.py --provider groq -o text "Write a linkedin post on how Sequoia Capital AI Ascent 2024 is interesting"
69
- [21:44:05] Using mixtral-8x7b-32768 on groq with temperature 0.0 search_agent.py:78
70
- [21:44:06] Optimized search query: Sequoia Capital AI Ascent 2024 interest search_agent.py:248
71
- Found 10 sources search_agent.py:252
72
- [21:44:08] Managed to extract content from 7 sources search_agent.py:256
73
- [21:44:12] Filtered 21 relevant content extracts search_agent.py:263
74
- ───────────────────────────────────── Response from groq ──────────────────────────────────────
75
- 🚀 Sequoia Capital's AI Ascent 2024 conference brought together some of the brightest minds in
76
- AI, including founders, researchers, and industry leaders. The event was a unique opportunity
77
- to discuss the state of AI and its future, focusing on the promise of generative AI to
78
- revolutionize industries and provide amazing productivity gains.
79
-
80
- 🌟 Highlights of the conference included talks by Sam Altman of OpenAI, Dylan Field of Figma,
81
- Alfred Mensch of Mistral, Daniela Amodei of Anthropic, Andrew Ng of AI Fund, CJ Desai of
82
- ServiceNow, and independent researcher Andrej Karpathy. Sessions covered a wide range of
83
- topics, from the merits of large and small models to the rise of reasoning agents, the future
84
- of compute, and the evolving AI ecosystem.
85
-
86
- 💡 One key takeaway from the event is the recognition that we are in a 'primordial soup' phase
87
- of AI development. This is a crucial moment for the technology to transition from being an idea
88
- to solving real-world problems efficiently. Factors like cheap compute power, fast networks,
89
- ubiquitous supercomputers, and readily available data are enabling AI as the next significant
90
- technology wave.
91
-
92
- 🔜 As we move forward, we can expect AI to become an even more significant part of our lives,
93
- revolutionizing various sectors and offering unprecedented value creation potential. Stay tuned
94
- for the upcoming advancements in AI, and let's continue to explore and harness its vast
95
- capabilities!
96
-
97
- _For more information, check out the [Sequoia Capital AI Ascent 2024 conference
98
- recap](https://www.sequoiacap.com/article/ai-ascent-2024/)._
99
-
100
- #AI #ArtificialIntelligence #GenerativeAI #SequoiaCapital #AIascent2024
101
- ────────────────────────────────────────────── ───────────────────────────────────────────────
102
 
 
 
103
  ```
104
 
105
  ## License
 
10
  license: apache-2.0
11
  ---
12
 
13
+ ⚠️ **This project is a demonstration / proof-of-concept and is not intended for use in production environments. It is provided as-is, without warranty or guarantee of any kind. The code and any accompanying materials are for educational, testing, or evaluation purposes only.** ⚠️
14
 
15
  # Simple Search Agent
16
 
17
+ This Python project provides a search agent that can perform web searches, optimize search queries, fetch and process web content, and generate responses using a language model and the retrieved information. It does a bit of what [Perplexity AI](https://www.perplexity.ai/) does.
 
18
 
19
  The Streamlit GUI hosted on 🤗 Spaces is [available to test](https://huggingface.co/spaces/CyranoB/search_agent)
20
 
21
+ This Python script and Streamlit GUI are a basic search agent that utilizes the LangChain library to perform optimized web searches, retrieve relevant content, and generate informative answers to user queries. The script supports multiple language models and providers, including OpenAI, Anthropic, and Groq.
22
 
23
  The main functionality of the script can be summarized as follows:
24
 
25
  1. **Query Optimization**: The user's input query is optimized for web search by identifying the key information requested and transforming it into a concise search string using the language model's capabilities.
26
  2. **Web Search**: The optimized search query is used to fetch search results from the Brave Search API. The script allows limiting the search to a specific domain and setting the maximum number of pages to retrieve.
27
  3. **Content Extraction**: The script fetches the content of the retrieved search results, handling both HTML and PDF documents. It extracts the main text content from web pages and text from PDF files.
28
+ 4. **Vectorization**: The extracted content is split into smaller text chunks using a RecursiveCharacterTextSplitter and vectorized using the specified embedding model. The vectorized data is stored in a FAISS vector store for efficient retrieval.
29
+ 5. **Query Answering**: The user's original query is answered by retrieving the most relevant text chunks from the vector store. The language model generates an informative answer by synthesizing the retrieved information, citing the sources used, and formatting the response in Markdown.
30
 
31
+ The script supports various options for customization, such as specifying the language model provider (OpenAI, Anthropic, Groq, or Ollama), temperature for language model generation, and output format (text or Markdown).
32
 
33
  Additionally, the script integrates with the LangChain Tracing V2 feature, allowing users to monitor and analyze the execution of their LangChain applications using the LangChain Studio.
34
 
 
47
 
48
  1. Clone this repo
49
  2. Install the required dependencies:
50
+
51
+ ```bash
52
  pip install -r requirements.txt
53
  ```
54
+
55
  3. Set up API keys:
56
+
57
+ - You will need API keys for the Brave Search API and LLM API.
58
  - Add your API keys to the `.env` file. Use `dotenv.sample` to create this file.
59
 
60
  ## Usage
61
 
62
+ You can run the search agent from the command line using the following syntax:
63
+
64
+ ```bash
65
+ python search_agent.py [OPTIONS] SEARCH_QUERY
66
  ```
 
 
67
 
68
+ ### Options:
69
+
70
+ - `-h`, `--help`: Show this help message and exit.
71
+ - `--version`: Show the program's version number and exit.
72
+ - `-c`, `--copywrite`: First produce a draft, review it, and rewrite for a final text.
73
+ - `-d DOMAIN`, `--domain=DOMAIN`: Limit search to a specific domain.
74
+ - `-t TEMP`, `--temperature=TEMP`: Set the temperature of the LLM [default: 0.0].
75
+ - `-m MODEL`, `--model=MODEL`: Use a specific model [default: openai/gpt-4o-mini].
76
+ - `-e MODEL`, `--embedding_model=MODEL`: Use a specific embedding model [default: same provider as model].
77
+ - `-n NUM`, `--max_pages=NUM`: Max number of pages to retrieve [default: 10].
78
+ - `-x NUM`, `--max_extracts=NUM`: Max number of page extracts to consider [default: 7].
79
+ - `-s`, `--use_selenium`: Use selenium to fetch content from the web [default: False].
80
+ - `-o TEXT`, `--output=TEXT`: Output format (choices: text, markdown) [default: markdown].
81
+
82
+ ### Examples
83
 
84
+ ```bash
85
+ python search_agent.py -m openai/gpt-4o-mini "Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show."
86
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ ```bash
89
+ python search_agent.py -m openai -e ollama -t 0.7 -n 20 -x 15 "Write a linked post about the state of M&A for startups in 2024. Write in the style of Russ from TV show Silicon Valley" -s
90
  ```
91
 
92
  ## License
models.py CHANGED
@@ -31,17 +31,12 @@ from langchain_together.embeddings import TogetherEmbeddings
31
 
32
 
33
  def get_model(provider_model, temperature=0.0):
34
- provider, model = (provider_model.split('/') + [None])[:2]
35
  match provider:
36
  case 'bedrock':
37
- #credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
38
  if model is None:
39
  model = "anthropic.claude-3-sonnet-20240229-v1:0"
40
- chat_llm = ChatBedrockConverse(
41
- #credentials_profile_name=credentials_profile_name,
42
- model=model,
43
- temperature=temperature,
44
- )
45
  case 'cohere':
46
  if model is None:
47
  model = 'command-r-plus'
@@ -52,7 +47,7 @@ def get_model(provider_model, temperature=0.0):
52
  chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
53
  case 'googlegenerativeai':
54
  if model is None:
55
- model = "gemini-1.5-pro"
56
  chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
57
  max_tokens=None, timeout=None, max_retries=2,)
58
  case 'groq':
@@ -82,16 +77,12 @@ def get_model(provider_model, temperature=0.0):
82
 
83
 
84
  def get_embedding_model(provider_embedding_model):
85
- provider, model = (provider_embedding_model.split('/') + [None])[:2]
86
  match provider:
87
  case 'bedrock':
88
- #credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
89
  if model is None:
90
- model = "cohere.embed-multilingual-v3"
91
- embedding_model = BedrockEmbeddings(
92
- model_id=model,
93
- #credentials_profile_name=credentials_profile_name
94
- )
95
  case 'cohere':
96
  if model is None:
97
  model = "embed-english-light-v3.0"
@@ -118,11 +109,11 @@ def get_embedding_model(provider_embedding_model):
118
  raise ValueError(f"Cannot use Perplexity for embedding model")
119
  case 'together':
120
  if model is None:
121
- model = 'BAAI/bge-base-en-v1.5'
122
  embedding_model = TogetherEmbeddings(model=model)
123
  case _:
124
  raise ValueError(f"Unknown LLM provider {provider}")
125
-
126
  return embedding_model
127
 
128
 
@@ -233,7 +224,7 @@ class TestGetModel(unittest.TestCase):
233
  @patch('models.ChatGroq')
234
  def test_groq_model(self, mock_groq):
235
  result = get_model('groq')
236
- mock_groq.assert_called_once_with(model_name='llama-3.1-8b-instant', temperature=0.0)
237
  self.assertEqual(result, mock_groq.return_value)
238
 
239
  @patch('models.ChatOllama')
 
31
 
32
 
33
  def get_model(provider_model, temperature=0.0):
34
+ provider, model = (provider_model.rstrip('/').split('/') + [None])[:2]
35
  match provider:
36
  case 'bedrock':
 
37
  if model is None:
38
  model = "anthropic.claude-3-sonnet-20240229-v1:0"
39
+ chat_llm = ChatBedrockConverse(model=model, temperature=temperature)
 
 
 
 
40
  case 'cohere':
41
  if model is None:
42
  model = 'command-r-plus'
 
47
  chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
48
  case 'googlegenerativeai':
49
  if model is None:
50
+ model = "gemini-1.5-flash"
51
  chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
52
  max_tokens=None, timeout=None, max_retries=2,)
53
  case 'groq':
 
77
 
78
 
79
  def get_embedding_model(provider_embedding_model):
80
+ provider, model = (provider_embedding_model.rstrip('/').split('/') + [None])[:2]
81
  match provider:
82
  case 'bedrock':
 
83
  if model is None:
84
+ model = "amazon.titan-embed-text-v2:0"
85
+ embedding_model = BedrockEmbeddings(model_id=model)
 
 
 
86
  case 'cohere':
87
  if model is None:
88
  model = "embed-english-light-v3.0"
 
109
  raise ValueError(f"Cannot use Perplexity for embedding model")
110
  case 'together':
111
  if model is None:
112
+ model = 'togethercomputer/m2-bert-80M-2k-retrieval'
113
  embedding_model = TogetherEmbeddings(model=model)
114
  case _:
115
  raise ValueError(f"Unknown LLM provider {provider}")
116
+
117
  return embedding_model
118
 
119
 
 
224
  @patch('models.ChatGroq')
225
  def test_groq_model(self, mock_groq):
226
  result = get_model('groq')
227
+ mock_groq.assert_called_once_with(model_name='llama2-70b-4096', temperature=0.0)
228
  self.assertEqual(result, mock_groq.return_value)
229
 
230
  @patch('models.ChatOllama')
search_agent.py CHANGED
@@ -22,9 +22,9 @@ Options:
22
  -d domain --domain=domain Limit search to a specific domain
23
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
24
  -m model --model=model Use a specific model [default: openai/gpt-4o-mini]
25
- -e model --embedding_model=model Use a specific embedding model [default: openai/text-embedding-3-small]
26
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
27
- -e num --max_extracts=num Max number of page extract to consider [default: 5]
28
  -s --use_selenium Use selenium to fetch content from the web [default: False]
29
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
30
 
@@ -54,10 +54,10 @@ dotenv.load_dotenv()
54
  def get_selenium_driver():
55
  from selenium import webdriver
56
  from selenium.webdriver.chrome.options import Options
57
- from selenium.common.exceptions import TimeoutException
58
 
59
  chrome_options = Options()
60
- chrome_options.add_argument("headless")
61
  chrome_options.add_argument("--disable-extensions")
62
  chrome_options.add_argument("--disable-gpu")
63
  chrome_options.add_argument("--no-sandbox")
@@ -66,8 +66,12 @@ def get_selenium_driver():
66
  chrome_options.add_argument('--blink-settings=imagesEnabled=false')
67
  chrome_options.add_argument("--window-size=1920,1080")
68
 
69
- driver = webdriver.Chrome(options=chrome_options)
70
- return driver
 
 
 
 
71
 
72
  callbacks = []
73
  if os.getenv("LANGCHAIN_API_KEY"):
@@ -88,7 +92,11 @@ def main(arguments):
88
  query = arguments["SEARCH_QUERY"]
89
 
90
  chat = md.get_model(model, temperature)
91
- embedding_model = md.get_embedding_model(embedding_model)
 
 
 
 
92
 
93
  with console.status(f"[bold green]Optimizing query for search: {query}"):
94
  optimize_search_query = wr.optimize_search_query(chat, query)
 
22
  -d domain --domain=domain Limit search to a specific domain
23
  -t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
24
  -m model --model=model Use a specific model [default: openai/gpt-4o-mini]
25
+ -e model --embedding_model=model Use a specific embedding model [default: same provider as model]
26
  -n num --max_pages=num Max number of pages to retrieve [default: 10]
27
+ -x num --max_extracts=num Max number of page extract to consider [default: 7]
28
  -s --use_selenium Use selenium to fetch content from the web [default: False]
29
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
30
 
 
54
  def get_selenium_driver():
55
  from selenium import webdriver
56
  from selenium.webdriver.chrome.options import Options
57
+ from selenium.common.exceptions import WebDriverException
58
 
59
  chrome_options = Options()
60
+ chrome_options.add_argument("--headless")
61
  chrome_options.add_argument("--disable-extensions")
62
  chrome_options.add_argument("--disable-gpu")
63
  chrome_options.add_argument("--no-sandbox")
 
66
  chrome_options.add_argument('--blink-settings=imagesEnabled=false')
67
  chrome_options.add_argument("--window-size=1920,1080")
68
 
69
+ try:
70
+ driver = webdriver.Chrome(options=chrome_options)
71
+ return driver
72
+ except WebDriverException as e:
73
+ print(f"Error creating Selenium WebDriver: {e}")
74
+ return None
75
 
76
  callbacks = []
77
  if os.getenv("LANGCHAIN_API_KEY"):
 
92
  query = arguments["SEARCH_QUERY"]
93
 
94
  chat = md.get_model(model, temperature)
95
+ if embedding_model.lower() == "same provider as model":
96
+ provider = model.split('/')[0]
97
+ embedding_model = md.get_embedding_model(f"{provider}/")
98
+ else:
99
+ embedding_model = md.get_embedding_model(embedding_model)
100
 
101
  with console.status(f"[bold green]Optimizing query for search: {query}"):
102
  optimize_search_query = wr.optimize_search_query(chat, query)
search_agent_ui.py CHANGED
@@ -58,6 +58,8 @@ if "models" not in st.session_state:
58
  models = []
59
  if os.getenv("FIREWORKS_API_KEY"):
60
  models.append("fireworks")
 
 
61
  if os.getenv("COHERE_API_KEY"):
62
  models.append("cohere")
63
  if os.getenv("OPENAI_API_KEY"):
@@ -74,7 +76,7 @@ with st.sidebar.expander("Options", expanded=False):
74
  model_provider = st.selectbox("Model provider 🧠", st.session_state["models"])
75
  temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
76
  max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to retrive from the internet")
77
- top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
78
  reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
79
 
80
  with st.sidebar.expander("Links", expanded=False):
@@ -148,13 +150,30 @@ if prompt := st.chat_input("Enter you instructions..." ):
148
 
149
  with st.chat_message("assistant"):
150
  st_cb = StreamHandler(st.empty())
151
- if hasattr(chat, 'stream'):
152
- response = ""
153
- for chunk in chat.stream(rag_prompt, config={"callbacks": [st_cb, ls_tracer]}):
154
- response += chunk.content
155
- else:
156
- result = chat.invoke(rag_prompt, config={"callbacks": [st_cb, ls_tracer]})
157
- response = result.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  response = response.strip()
160
  message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
 
58
  models = []
59
  if os.getenv("FIREWORKS_API_KEY"):
60
  models.append("fireworks")
61
+ if os.getenv("TOGETHER_API_KEY"):
62
+ models.append("together")
63
  if os.getenv("COHERE_API_KEY"):
64
  models.append("cohere")
65
  if os.getenv("OPENAI_API_KEY"):
 
76
  model_provider = st.selectbox("Model provider 🧠", st.session_state["models"])
77
  temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
78
  max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 10, help="How many web pages to retrive from the internet")
79
+ top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 10, help="How many of the top extracts to consider")
80
  reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
81
 
82
  with st.sidebar.expander("Links", expanded=False):
 
150
 
151
  with st.chat_message("assistant"):
152
  st_cb = StreamHandler(st.empty())
153
+ response = ""
154
+ for chunk in chat.stream(rag_prompt, config={"callbacks": [ls_tracer]}):
155
+ if isinstance(chunk, dict):
156
+ chunk_text = chunk.get('text') or chunk.get('content', '')
157
+ elif isinstance(chunk, str):
158
+ chunk_text = chunk
159
+ elif hasattr(chunk, 'content'):
160
+ chunk_text = chunk.content
161
+ else:
162
+ chunk_text = str(chunk)
163
+
164
+ if isinstance(chunk_text, list):
165
+ chunk_text = ' '.join(
166
+ item['text'] if isinstance(item, dict) and 'text' in item
167
+ else str(item)
168
+ for item in chunk_text if item is not None
169
+ )
170
+ elif chunk_text is not None:
171
+ chunk_text = str(chunk_text)
172
+ else:
173
+ continue
174
+
175
+ response += chunk_text
176
+ st_cb.on_llm_new_token(chunk_text)
177
 
178
  response = response.strip()
179
  message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
web_crawler.py CHANGED
@@ -8,8 +8,7 @@ from trafilatura import extract
8
  from selenium.common.exceptions import TimeoutException
9
  from langchain_core.documents.base import Document
10
  from langchain_experimental.text_splitter import SemanticChunker
11
- from langchain.text_splitter import RecursiveCharacterTextSplitter
12
- from langchain_openai import OpenAIEmbeddings
13
  from langchain_community.vectorstores.faiss import FAISS
14
  from langsmith import traceable
15
  import requests
@@ -130,7 +129,6 @@ def get_links_contents(sources, get_driver_func=None, use_selenium=False):
130
  @traceable(run_type="embedding")
131
  def vectorize(contents, embedding_model):
132
  documents = []
133
- total_content_length = 0
134
  for content in contents:
135
  try:
136
  page_content = content['page_content']
@@ -138,38 +136,31 @@ def vectorize(contents, embedding_model):
138
  metadata = {'title': content['title'], 'source': content['link']}
139
  doc = Document(page_content=content['page_content'], metadata=metadata)
140
  documents.append(doc)
141
- total_content_length += len(page_content)
142
  except Exception as e:
143
- print(f"[gray]Error processing content for {content['link']}: {e}")
144
 
145
- # Define a threshold for when to use pre-splitting (e.g., 1 million characters)
146
- pre_split_threshold = 1_000_000
147
 
148
- if total_content_length > pre_split_threshold:
149
- # Use pre-splitting for large datasets
150
- pre_splitter = RecursiveCharacterTextSplitter(
151
- chunk_size=2000,
152
- chunk_overlap=200,
153
- length_function=len,
154
- )
155
- documents = pre_splitter.split_documents(documents)
156
 
157
- semantic_chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
158
-
159
  vector_store = None
160
- batch_size = 200 # Adjust this value if needed
161
 
162
- for i in range(0, len(documents), batch_size):
163
- batch = documents[i:i+batch_size]
164
-
165
- # Split each document in the batch using SemanticChunker
166
- chunked_docs = []
167
- for doc in batch:
168
- chunked_docs.extend(semantic_chunker.split_documents([doc]))
169
 
170
  if vector_store is None:
171
- vector_store = FAISS.from_documents(chunked_docs, embedding_model)
172
  else:
173
- vector_store.add_documents(chunked_docs)
 
 
 
 
 
 
174
 
175
  return vector_store
 
8
  from selenium.common.exceptions import TimeoutException
9
  from langchain_core.documents.base import Document
10
  from langchain_experimental.text_splitter import SemanticChunker
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
 
12
  from langchain_community.vectorstores.faiss import FAISS
13
  from langsmith import traceable
14
  import requests
 
129
  @traceable(run_type="embedding")
130
  def vectorize(contents, embedding_model):
131
  documents = []
 
132
  for content in contents:
133
  try:
134
  page_content = content['page_content']
 
136
  metadata = {'title': content['title'], 'source': content['link']}
137
  doc = Document(page_content=content['page_content'], metadata=metadata)
138
  documents.append(doc)
 
139
  except Exception as e:
140
+ print(f"Error processing content for {content['link']}: {e}")
141
 
142
+ # Initialize recursive text splitter
143
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
144
 
145
+ # Split documents
146
+ split_documents = text_splitter.split_documents(documents)
 
 
 
 
 
 
147
 
148
+ # Create vector store
 
149
  vector_store = None
150
+ batch_size = 250 # Slightly less than 256 to be safe
151
 
152
+ for i in range(0, len(split_documents), batch_size):
153
+ batch = split_documents[i:i+batch_size]
 
 
 
 
 
154
 
155
  if vector_store is None:
156
+ vector_store = FAISS.from_documents(batch, embedding_model)
157
  else:
158
+ texts = [doc.page_content for doc in batch]
159
+ metadatas = [doc.metadata for doc in batch]
160
+ embeddings = embedding_model.embed_documents(texts)
161
+ vector_store.add_embeddings(
162
+ list(zip(texts, embeddings)),
163
+ metadatas
164
+ )
165
 
166
  return vector_store
web_rag.py CHANGED
@@ -96,6 +96,12 @@ def get_optimized_search_messages(query):
96
  Exmaple:
97
  Question: Write a short linkedin about how the "freakeconomics" book previsions didn't pan out
98
  freakeconomics book predictions failed**
 
 
 
 
 
 
99
  """
100
  )
101
  human_message = HumanMessage(
@@ -293,4 +299,14 @@ def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10,
293
  def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
294
  prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
295
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
296
- return response.content
 
 
 
 
 
 
 
 
 
 
 
96
  Exmaple:
97
  Question: Write a short linkedin about how the "freakeconomics" book previsions didn't pan out
98
  freakeconomics book predictions failed**
99
+ Example:
100
+ Question: Write an LinkedIn post about startup M&A in the style of Andrew Ng
101
+ startup M&A**
102
+ Example:
103
+ Question: Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show.
104
+ startup current state M&A**
105
  """
106
  )
107
  human_message = HumanMessage(
 
299
  def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
300
  prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
301
  response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
302
+
303
+ # Ensure we're returning a string
304
+ if isinstance(response.content, list):
305
+ # If it's a list, join the elements into a single string
306
+ return ' '.join(str(item) for item in response.content)
307
+ elif isinstance(response.content, str):
308
+ # If it's already a string, return it as is
309
+ return response.content
310
+ else:
311
+ # If it's neither a list nor a string, convert it to a string
312
+ return str(response.content)