juliaturc committed
Commit
2db1bb0
·
2 Parent(s): 40b4763 5b5303c

Merge pull request #13 from Storia-AI/julia/marqo

Files changed (8)
  1. README.md +55 -27
  2. requirements.txt +1 -0
  3. src/chat.py +14 -23
  4. src/chunker.py +25 -21
  5. src/embedder.py +63 -22
  6. src/index.py +51 -28
  7. src/repo_manager.py +7 -21
  8. src/vector_store.py +59 -6
README.md CHANGED
@@ -7,40 +7,68 @@
7
  **Ok, but why chat with a codebase?**
8
 
9
  Sometimes you just want to learn how a codebase works and how to integrate it, without spending hours sifting through
10
- the code itself.
11
 
12
- `repo2vec` is like GitHub Copilot but with the most up-to-date information about your repo.
13
 
14
- Features:
15
  - **Dead-simple set-up.** Run *two scripts* and you have a functional chat interface for your code. That's really it.
16
  - **Heavily documented answers.** Every response shows where in the code the context for the answer was pulled from. Let's build trust in the AI.
17
  - **Plug-and-play.** Want to improve the algorithms powering the code understanding/generation? We've made every component of the pipeline easily swappable. Customize to your heart's content.
18
 
19
- Here are the two scripts you need to run:
20
- ```
21
- pip install -r requirements.txt
22
 
23
- export GITHUB_REPO_NAME=...
24
- export OPENAI_API_KEY=...
25
- export PINECONE_API_KEY=...
26
- export PINECONE_INDEX_NAME=...
27
 
28
- python src/index.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
29
- python src/chat.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
30
- ```
31
- This will index your entire codebase in a vector DB, then bring up a `gradio` app where you can ask questions about it.
 
 
32
 
33
- The assistant responses always include GitHub links to the documents retrieved for each query.
 
 
34
 
35
- If you want to publicly host your chat experience, set `--share=true`:
36
- ```
37
- python src/chat.py $GITHUB_REPO_NAME --share=true ...
38
  ```
 
39
 
40
- That's it.
 
 
 
 
 
41
 
42
- Here is, for example, a conversation about the repo [Storia-AI/image-eval](https://github.com/Storia-AI/image-eval):
43
- ![screenshot](assets/chat_screenshot.png)
44
 
45
  # Peeking under the hood
46
 
@@ -50,10 +78,11 @@ The `src/index.py` script performs the following steps:
50
  - Make sure to set the `GITHUB_TOKEN` environment variable for private repositories.
51
  2. **Chunks files**. See [Chunker](src/chunker.py).
52
  - For code files, we implement a special `CodeChunker` that takes the parse tree into account.
53
- 3. **Batch-embeds chunks**. See [Embedder](src/embedder.py).
54
- - By default, we use OpenAI's [batch embedding API](https://platform.openai.com/docs/guides/batch/overview), which is much faster and cheaper than the regular synchronous embedding API.
 
55
  4. **Stores embeddings in a vector store**. See [VectorStore](src/vector_store.py).
56
- - By default, we use [Pinecone](https://pinecone.io) as a vector store, but you can easily plug in your own.
57
 
58
  Note you can specify an inclusion or exclusion set for the file extensions you want indexed. To specify an extension inclusion set, you can add the `--include` flag:
59
  ```
@@ -77,10 +106,9 @@ The sources are conveniently surfaced in the chat and linked directly to GitHub.
77
 
78
  # Want your repository hosted?
79
 
80
- We're working to make all code on the internet searchable and understandable for devs. If you would like help hosting
81
- your repository, we're onboarding a handful of repos onto our infrastructure **for free**.
82
 
83
- You'll get a dedicated url for your repo like `https://sage.storia.ai/[REPO_NAME]`. Just send us a message at [founders@storia.ai](mailto:founders@storia.ai)!
84
 
85
  ![](assets/sage.gif)
86
 
 
7
  **Ok, but why chat with a codebase?**
8
 
9
  Sometimes you just want to learn how a codebase works and how to integrate it, without spending hours sifting through
10
+ the code itself.
11
 
12
+ `repo2vec` is like GitHub Copilot but with the most up-to-date information about your repo.
13
 
14
+ Features:
15
  - **Dead-simple set-up.** Run *two scripts* and you have a functional chat interface for your code. That's really it.
16
  - **Heavily documented answers.** Every response shows where in the code the context for the answer was pulled from. Let's build trust in the AI.
17
  - **Plug-and-play.** Want to improve the algorithms powering the code understanding/generation? We've made every component of the pipeline easily swappable. Customize to your heart's content.
18
 
19
+ # How to run it
20
+ ## Indexing the codebase
21
+ We currently support two options for indexing the codebase:
22
 
23
+ 1. **Locally**, using the open-source [Marqo vector store](https://github.com/marqo-ai/marqo). Marqo is both an embedder (you can choose your favorite embedding model from Hugging Face) and a vector store.
 
 
 
24
 
25
+ You can bring up a Marqo instance using Docker:
26
+ ```
27
+ docker rm -f marqo
28
+ docker pull marqoai/marqo:latest
29
+ docker run --name marqo -it -p 8882:8882 marqoai/marqo:latest
30
+ ```
31
 
32
+ Then, to index your codebase, run:
33
+ ```
34
+ pip install -r requirements.txt
35
+
36
+ python src/index.py \
37
+ github-repo-name \
38
+ --embedder_type=marqo \
39
+ --vector_store_type=marqo \
40
+ --index_name=your-index-name  # github-repo-name is e.g. Storia-AI/repo2vec
41
+ ```
42
+
43
+ 2. **Using external providers** (OpenAI for embeddings and [Pinecone](https://www.pinecone.io/) for the vector store). To index your codebase, run:
44
+ ```
45
+ pip install -r requirements.txt
46
+
47
+ export OPENAI_API_KEY=...
48
+ export PINECONE_API_KEY=...
49
+
50
+ python src/index.py \
51
+ github-repo-name \
52
+ --embedder_type=openai \
53
+ --vector_store_type=pinecone \
54
+ --index_name=your-index-name  # github-repo-name is e.g. Storia-AI/repo2vec
55
+ ```
56
+ We are planning on adding more providers soon, so that you can mix and match them. Contributions are also welcome!
57
+
58
+ ## Chatting with the codebase
59
+ To bring up a `gradio` app where you can chat with your codebase, simply point it at your vector store:
60
 
 
 
 
61
  ```
62
+ export OPENAI_API_KEY=...
63
 
64
+ python src/chat.py \
65
+ github-repo-name \
66
+ --vector_store_type=marqo \
67
+ --index_name=your-index-name  # vector_store_type can also be pinecone
68
+ ```
69
+ To get a public URL for your chat app, set `--share=true`.
70
 
71
+ Currently, the chat will use OpenAI's GPT-4, but we are working on adding support for other providers and local LLMs. Stay tuned!
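The "Sources" footer appended to each answer boils down to deduplicating the filenames of the retrieved chunks (preserving order) and turning them into GitHub links. A minimal sketch; the link format here is an assumption, since the real URLs are built by the repo's `RepoManager.github_link_for_file`:

```python
def append_sources_to_response(answer: str, retrieved_filenames: list[str], repo_id: str) -> str:
    """Appends deduplicated GitHub links for the retrieved files to the answer."""
    # dict.fromkeys deduplicates while preserving the order of first occurrence.
    filenames = list(dict.fromkeys(retrieved_filenames))
    # Hypothetical link format; the real one comes from RepoManager.
    links = [f"https://github.com/{repo_id}/blob/main/{f}" for f in filenames]
    return answer + "\n\nSources:\n" + "\n".join(links)

print(append_sources_to_response(
    "The chunker is token-aware.",
    ["src/chunker.py", "README.md", "src/chunker.py"],
    "Storia-AI/repo2vec",
))
```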
 
72
 
73
  # Peeking under the hood
74
 
 
78
  - Make sure to set the `GITHUB_TOKEN` environment variable for private repositories.
79
  2. **Chunks files**. See [Chunker](src/chunker.py).
80
  - For code files, we implement a special `CodeChunker` that takes the parse tree into account.
81
+ 3. **Batch-embeds chunks**. See [Embedder](src/embedder.py). We currently support:
82
+ - [Marqo](https://github.com/marqo-ai/marqo) as an embedder, which allows you to specify your favorite Hugging Face embedding model;
83
+ - OpenAI's [batch embedding API](https://platform.openai.com/docs/guides/batch/overview), which is much faster and cheaper than the regular synchronous embedding API.
84
  4. **Stores embeddings in a vector store**. See [VectorStore](src/vector_store.py).
85
+ - We currently support [Marqo](https://github.com/marqo-ai/marqo) and [Pinecone](https://pinecone.io), but you can easily plug in your own.
86
 
87
  Note you can specify an inclusion or exclusion set for the file extensions you want indexed. To specify an extension inclusion set, you can add the `--include` flag:
88
  ```
 
106
 
107
  # Want your repository hosted?
108
 
109
+ We're working to make all code on the internet searchable and understandable for devs. You can check out our early product, [Code Sage](https://sage.storia.ai). We pre-indexed a slew of OSS repos, and you can index your desired ones by simply pasting a GitHub URL.
 
110
 
111
+ If you're the maintainer of an OSS repo and would like a dedicated page on Code Sage (e.g. `sage.storia.ai/your-repo`), then send us a message at [founders@storia.ai](mailto:founders@storia.ai). We'll do it for free!
112
 
113
  ![](assets/sage.gif)
114
 
requirements.txt CHANGED
@@ -4,6 +4,7 @@ gradio==4.42.0
4
  langchain==0.2.14
5
  langchain-community==0.2.12
6
  langchain-openai==0.1.22
 
7
  nbformat==5.10.4
8
  openai==1.42.0
9
  pinecone==5.0.1
 
4
  langchain==0.2.14
5
  langchain-community==0.2.12
6
  langchain-openai==0.1.22
7
+ marqo==3.7.0
8
  nbformat==5.10.4
9
  openai==1.42.0
10
  pinecone==5.0.1
src/chat.py CHANGED
@@ -5,16 +5,16 @@ You must run main.py first in order to index the codebase into a vector store.
5
 
6
  import argparse
7
 
8
- from dotenv import load_dotenv
9
-
10
  import gradio as gr
11
- from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 
 
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
13
  from langchain.schema import AIMessage, HumanMessage
14
- from langchain_community.vectorstores import Pinecone
15
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
16
- from langchain_openai import ChatOpenAI, OpenAIEmbeddings
17
 
 
18
  from repo_manager import RepoManager
19
 
20
  load_dotenv()
@@ -23,14 +23,7 @@ load_dotenv()
23
  def build_rag_chain(args):
24
  """Builds a RAG chain via LangChain."""
25
  llm = ChatOpenAI(model=args.openai_model)
26
-
27
- vectorstore = Pinecone.from_existing_index(
28
- index_name=args.pinecone_index_name,
29
- embedding=OpenAIEmbeddings(),
30
- namespace=args.repo_id,
31
- )
32
-
33
- retriever = vectorstore.as_retriever()
34
 
35
  # Prompt to contextualize the latest query based on the chat history.
36
  contextualize_q_system_prompt = (
@@ -45,9 +38,7 @@ def build_rag_chain(args):
45
  ("human", "{input}"),
46
  ]
47
  )
48
- history_aware_retriever = create_history_aware_retriever(
49
- llm, retriever, contextualize_q_prompt
50
- )
51
 
52
  qa_system_prompt = (
53
  f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
@@ -76,9 +67,7 @@ def append_sources_to_response(response):
76
  # Deduplicate filenames while preserving their order.
77
  filenames = list(dict.fromkeys(filenames))
78
  repo_manager = RepoManager(args.repo_id)
79
- github_links = [
80
- repo_manager.github_link_for_file(filename) for filename in filenames
81
- ]
82
  return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
83
 
84
 
@@ -90,8 +79,12 @@ if __name__ == "__main__":
90
  default="gpt-4",
91
  help="The OpenAI model to use for response generation",
92
  )
 
 
93
  parser.add_argument(
94
- "--pinecone_index_name", required=True, help="Pinecone index name"
 
 
95
  )
96
  parser.add_argument(
97
  "--share",
@@ -109,9 +102,7 @@ if __name__ == "__main__":
109
  history_langchain_format.append(HumanMessage(content=human))
110
  history_langchain_format.append(AIMessage(content=ai))
111
  history_langchain_format.append(HumanMessage(content=message))
112
- response = rag_chain.invoke(
113
- {"input": message, "chat_history": history_langchain_format}
114
- )
115
  answer = append_sources_to_response(response)
116
  return answer
117
 
 
5
 
6
  import argparse
7
 
 
 
8
  import gradio as gr
9
+ from dotenv import load_dotenv
10
+ from langchain.chains import (create_history_aware_retriever,
11
+ create_retrieval_chain)
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
13
  from langchain.schema import AIMessage, HumanMessage
 
14
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15
+ from langchain_openai import ChatOpenAI
16
 
17
+ import vector_store
18
  from repo_manager import RepoManager
19
 
20
  load_dotenv()
 
23
  def build_rag_chain(args):
24
  """Builds a RAG chain via LangChain."""
25
  llm = ChatOpenAI(model=args.openai_model)
26
+ retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
 
 
 
 
 
 
 
27
 
28
  # Prompt to contextualize the latest query based on the chat history.
29
  contextualize_q_system_prompt = (
 
38
  ("human", "{input}"),
39
  ]
40
  )
41
+ history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
 
 
42
 
43
  qa_system_prompt = (
44
  f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
 
67
  # Deduplicate filenames while preserving their order.
68
  filenames = list(dict.fromkeys(filenames))
69
  repo_manager = RepoManager(args.repo_id)
70
+ github_links = [repo_manager.github_link_for_file(filename) for filename in filenames]
 
 
71
  return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
72
 
73
 
 
79
  default="gpt-4",
80
  help="The OpenAI model to use for response generation",
81
  )
82
+ parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
83
+ parser.add_argument("--index_name", required=True, help="Vector store index name")
84
  parser.add_argument(
85
+ "--marqo_url",
86
+ default="http://localhost:8882",
87
+ help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
88
  )
89
  parser.add_argument(
90
  "--share",
 
102
  history_langchain_format.append(HumanMessage(content=human))
103
  history_langchain_format.append(AIMessage(content=ai))
104
  history_langchain_format.append(HumanMessage(content=message))
105
+ response = rag_chain.invoke({"input": message, "chat_history": history_langchain_format})
 
 
106
  answer = append_sources_to_response(response)
107
  return answer
108
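The new `vector_store.build_from_args(args)` call in `chat.py` implies a small factory that dispatches on `--vector_store_type`. A hedged sketch of that pattern; the class names here are illustrative stand-ins, not the repo's actual API:

```python
import argparse
from dataclasses import dataclass


@dataclass
class MarqoStore:  # stand-in for the repo's Marqo-backed vector store
    index_name: str
    url: str


@dataclass
class PineconeStore:  # stand-in for the repo's Pinecone-backed vector store
    index_name: str


def build_from_args(args) -> object:
    """Builds the right vector store based on CLI args."""
    if args.vector_store_type == "marqo":
        return MarqoStore(index_name=args.index_name, url=args.marqo_url)
    if args.vector_store_type == "pinecone":
        return PineconeStore(index_name=args.index_name)
    raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")


args = argparse.Namespace(
    vector_store_type="marqo", index_name="my-index", marqo_url="http://localhost:8882"
)
store = build_from_args(args)
```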
 
src/chunker.py CHANGED
@@ -1,12 +1,12 @@
1
  """Chunker abstraction and implementations."""
2
 
3
  import logging
4
- import nbformat
5
  from abc import ABC, abstractmethod
6
  from dataclasses import dataclass
7
  from functools import lru_cache
8
  from typing import List, Optional
9
 
 
10
  import pygments
11
  import tiktoken
12
  from semchunk import chunk as chunk_via_semchunk
@@ -30,11 +30,26 @@ class Chunk:
30
  """The text content to be embedded. Might contain information beyond just the text snippet from the file."""
31
  return self._content
32
 
 
33
  def populate_content(self, file_content: str):
34
  """Populates the content of the chunk with the file path and file content."""
35
- self._content = (
36
- self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
37
- )
38
 
39
  def num_tokens(self, tokenizer):
40
  """Counts the number of tokens in the chunk."""
@@ -98,9 +113,7 @@ class CodeChunker(Chunker):
98
 
99
  if not node.children:
100
  # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
101
- return self.text_chunker.chunk(
102
- filename, file_content[node.start_byte : node.end_byte]
103
- )
104
 
105
  chunks = []
106
  for child in node.children:
@@ -116,11 +129,7 @@ class CodeChunker(Chunker):
116
  for chunk in chunks:
117
  if not merged_chunks:
118
  merged_chunks.append(chunk)
119
- elif (
120
- merged_chunks[-1].num_tokens(self.tokenizer)
121
- + chunk.num_tokens(self.tokenizer)
122
- < self.max_tokens - 50
123
- ):
124
  # There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
125
  # at this point, because tokenization is not necessarily additive.
126
  merged = Chunk(
@@ -186,9 +195,7 @@ class CodeChunker(Chunker):
186
  # a bug in the code.
187
  assert chunk.content
188
  size = chunk.num_tokens(self.tokenizer)
189
- assert (
190
- size <= self.max_tokens
191
- ), f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
192
 
193
  return chunks
194
 
@@ -200,17 +207,13 @@ class TextChunker(Chunker):
200
  self.max_tokens = max_tokens
201
 
202
  tokenizer = tiktoken.get_encoding("cl100k_base")
203
- self.count_tokens = lambda text: len(
204
- tokenizer.encode(text, disallowed_special=())
205
- )
206
 
207
  def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
208
  """Chunks a text file into smaller pieces."""
209
  # We need to allocate some tokens for the filename, which is part of the chunk content.
210
  extra_tokens = self.count_tokens(file_path + "\n\n")
211
- text_chunks = chunk_via_semchunk(
212
- file_content, self.max_tokens - extra_tokens, self.count_tokens
213
- )
214
 
215
  chunks = []
216
  start = 0
@@ -235,6 +238,7 @@ class IPYNBChunker(Chunker):
235
 
236
  Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
237
  """
 
238
  def __init__(self, code_chunker: CodeChunker):
239
  self.code_chunker = code_chunker
240
 
 
1
  """Chunker abstraction and implementations."""
2
 
3
  import logging
 
4
  from abc import ABC, abstractmethod
5
  from dataclasses import dataclass
6
  from functools import lru_cache
7
  from typing import List, Optional
8
 
9
+ import nbformat
10
  import pygments
11
  import tiktoken
12
  from semchunk import chunk as chunk_via_semchunk
 
30
  """The text content to be embedded. Might contain information beyond just the text snippet from the file."""
31
  return self._content
32
 
33
+ @property
34
+ def to_metadata(self):
35
+ """Converts the chunk to a dictionary that can be passed to a vector store."""
36
+ # Some vector stores require the IDs to be ASCII.
37
+ filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
38
+ return {
39
40
+ "id": f"{filename_ascii}_{self.start_byte}_{self.end_byte}",
41
+ "filename": self.filename,
42
+ "start_byte": self.start_byte,
43
+ "end_byte": self.end_byte,
44
+ # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
45
+ # size limit. In that case, you can simply store the start/end bytes above, and fetch the content
46
+ # directly from the repository when needed.
47
+ "text": self.content,
48
+ }
49
+
50
  def populate_content(self, file_content: str):
51
  """Populates the content of the chunk with the file path and file content."""
52
+ self._content = self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
 
 
53
 
54
  def num_tokens(self, tokenizer):
55
  """Counts the number of tokens in the chunk."""
 
113
 
114
  if not node.children:
115
  # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
116
+ return self.text_chunker.chunk(filename, file_content[node.start_byte : node.end_byte])
 
 
117
 
118
  chunks = []
119
  for child in node.children:
 
129
  for chunk in chunks:
130
  if not merged_chunks:
131
  merged_chunks.append(chunk)
132
+ elif merged_chunks[-1].num_tokens(self.tokenizer) + chunk.num_tokens(self.tokenizer) < self.max_tokens - 50:
 
 
 
 
133
  # There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
134
  # at this point, because tokenization is not necessarily additive.
135
  merged = Chunk(
 
195
  # a bug in the code.
196
  assert chunk.content
197
  size = chunk.num_tokens(self.tokenizer)
198
+ assert size <= self.max_tokens, f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
 
 
199
 
200
  return chunks
201
 
 
207
  self.max_tokens = max_tokens
208
 
209
  tokenizer = tiktoken.get_encoding("cl100k_base")
210
+ self.count_tokens = lambda text: len(tokenizer.encode(text, disallowed_special=()))
 
 
211
 
212
  def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
213
  """Chunks a text file into smaller pieces."""
214
  # We need to allocate some tokens for the filename, which is part of the chunk content.
215
  extra_tokens = self.count_tokens(file_path + "\n\n")
216
+ text_chunks = chunk_via_semchunk(file_content, self.max_tokens - extra_tokens, self.count_tokens)
 
 
217
 
218
  chunks = []
219
  start = 0
 
238
 
239
  Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
240
  """
241
+
242
  def __init__(self, code_chunker: CodeChunker):
243
  self.code_chunker = code_chunker
244
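The `to_metadata` property added to `Chunk` above can be exercised in isolation. A small sketch of the ASCII-ID logic; the `Chunk` here is a trimmed stand-in for the repo's dataclass:

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    filename: str
    start_byte: int
    end_byte: int
    content: str = ""

    @property
    def to_metadata(self) -> dict:
        # Some vector stores require IDs to be ASCII, so non-ASCII characters
        # are dropped from the ID (but kept in the stored filename).
        filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
        return {
            "id": f"{filename_ascii}_{self.start_byte}_{self.end_byte}",
            "filename": self.filename,
            "start_byte": self.start_byte,
            "end_byte": self.end_byte,
            "text": self.content,
        }


chunk = Chunk(filename="src/café.py", start_byte=0, end_byte=42, content="def foo(): ...")
print(chunk.to_metadata["id"])  # the non-ASCII "é" is stripped from the ID only
```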
 
src/embedder.py CHANGED
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
7
  from collections import Counter
8
  from typing import Dict, Generator, List, Tuple
9
 
 
10
  from openai import OpenAI
11
 
12
  from chunker import Chunk, Chunker
@@ -19,7 +20,7 @@ class BatchEmbedder(ABC):
19
  """Abstract class for batch embedding of a repository."""
20
 
21
  @abstractmethod
22
- def embed_repo(self, chunks_per_batch: int):
23
  """Issues batch embedding jobs for the entire repository."""
24
 
25
  @abstractmethod
@@ -62,7 +63,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
62
  openai_batch_id = self._issue_job_for_chunks(
63
  sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
64
  )
65
- self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(sub_batch)
66
  if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
67
  logging.info("Reached the maximum number of embedding jobs. Stopping.")
68
  return
@@ -71,7 +72,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
71
  # Finally, commit the last batch.
72
  if batch:
73
  openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
74
- self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(batch)
75
  logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
76
 
77
  # Save the job IDs to a file, just in case this script is terminated by mistake.
@@ -171,22 +172,62 @@ class OpenAIBatchEmbedder(BatchEmbedder):
171
  },
172
  }
173
 
174
- @staticmethod
175
- def _metadata_for_chunks(chunks):
176
- metadata = []
177
- for chunk in chunks:
178
- filename_ascii = chunk.filename.encode("ascii", "ignore").decode("ascii")
179
- metadata.append(
180
- {
181
- # Some vector stores require the IDs to be ASCII.
182
- "id": f"{filename_ascii}_{chunk.start_byte}_{chunk.end_byte}",
183
- "filename": chunk.filename,
184
- "start_byte": chunk.start_byte,
185
- "end_byte": chunk.end_byte,
186
- # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
187
- # size limit. In that case, you can simply store the start/end bytes above, and fetch the content
188
- # directly from the repository when needed.
189
- "text": chunk.content,
190
- }
191
- )
192
- return metadata
 
 
7
  from collections import Counter
8
  from typing import Dict, Generator, List, Tuple
9
 
10
+ import marqo
11
  from openai import OpenAI
12
 
13
  from chunker import Chunk, Chunker
 
20
  """Abstract class for batch embedding of a repository."""
21
 
22
  @abstractmethod
23
+ def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
24
  """Issues batch embedding jobs for the entire repository."""
25
 
26
  @abstractmethod
 
63
  openai_batch_id = self._issue_job_for_chunks(
64
  sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
65
  )
66
+ self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in sub_batch]
67
  if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
68
  logging.info("Reached the maximum number of embedding jobs. Stopping.")
69
  return
 
72
  # Finally, commit the last batch.
73
  if batch:
74
  openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
75
+ self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in batch]
76
  logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
77
 
78
  # Save the job IDs to a file, just in case this script is terminated by mistake.
 
172
  },
173
  }
174
 
175
+
176
+ class MarqoEmbedder(BatchEmbedder):
177
+ """Embedder that uses the open-source Marqo vector search engine.
178
+
179
+ Embeddings can be stored locally (in which case the `url` passed to the constructor should point to localhost) or in the cloud.
180
+ """
181
+
182
+ def __init__(self, repo_manager: RepoManager, chunker: Chunker, index_name: str, url: str, model="hf/e5-base-v2"):
183
+ self.repo_manager = repo_manager
184
+ self.chunker = chunker
185
+ self.client = marqo.Client(url=url)
186
+ self.index = self.client.index(index_name)
187
+
188
+ all_index_names = [result["indexName"] for result in self.client.get_indexes()["results"]]
189
+ if index_name not in all_index_names:
190
+ self.client.create_index(index_name, model=model)
191
+
192
+ def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
193
+ """Issues batch embedding jobs for the entire repository."""
194
+ if chunks_per_batch > 64:
195
+ raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
196
+
197
+ chunk_count = 0
198
+ batch = []
199
+
200
+ for filepath, content in self.repo_manager.walk():
201
+ chunks = self.chunker.chunk(filepath, content)
202
+ chunk_count += len(chunks)
203
+ batch.extend(chunks)
204
+
205
+ if len(batch) > chunks_per_batch:
206
+ for i in range(0, len(batch), chunks_per_batch):
207
+ sub_batch = batch[i : i + chunks_per_batch]
208
+ logging.info("Indexing %d chunks...", len(sub_batch))
209
+ self.index.add_documents(
210
+ documents=[chunk.to_metadata for chunk in sub_batch],
211
+ tensor_fields=["text"],
212
+ )
213
+
214
+ if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
215
+ logging.info("Reached the maximum number of embedding jobs. Stopping.")
216
+ return
217
+ batch = []
218
+
219
+ # Finally, commit the last batch.
220
+ if batch:
221
+ self.index.add_documents(documents=[chunk.to_metadata for chunk in batch], tensor_fields=["text"])
222
+ logging.info(f"Successfully embedded {chunk_count} chunks.")
223
+
224
+ def embeddings_are_ready(self) -> bool:
225
+ """Checks whether the batch embedding jobs are done."""
226
+ # Marqo indexes documents synchronously, so once embed_repo() returns, the embeddings are ready.
227
+ return True
228
+
229
+ def download_embeddings(self) -> Generator[Vector, None, None]:
230
+ """Yields (chunk_metadata, embedding) pairs for each chunk in the repository."""
231
+ # Marqo stores embeddings as they are created, so they're already in the vector store. No need to download them
232
+ # as we would with e.g. OpenAI, Cohere, or some other cloud-based embedding service.
233
+ return []
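The sub-batching in `MarqoEmbedder.embed_repo` (Marqo caps a batch at 64 documents) reduces to a plain slicing loop. A standalone sketch with `add_documents` stubbed out as a callback:

```python
MARQO_MAX_BATCH = 64  # Marqo enforces a limit of 64 documents per add_documents call.


def index_in_sub_batches(documents: list[dict], add_documents, chunks_per_batch: int = MARQO_MAX_BATCH) -> int:
    """Sends documents to the vector store in slices of at most chunks_per_batch.

    Returns the number of indexing calls issued.
    """
    if chunks_per_batch > MARQO_MAX_BATCH:
        raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
    jobs = 0
    for i in range(0, len(documents), chunks_per_batch):
        add_documents(documents[i : i + chunks_per_batch])
        jobs += 1
    return jobs


batches = []  # collects each sub-batch in place of a real Marqo client
n = index_in_sub_batches([{"id": str(i)} for i in range(150)], batches.append)
```

With 150 documents and the default batch size, this issues three calls of sizes 64, 64, and 22.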
src/index.py CHANGED
@@ -5,19 +5,14 @@ import logging
5
  import time
6
 
7
  from chunker import UniversalChunker
8
- from embedder import OpenAIBatchEmbedder
9
  from repo_manager import RepoManager
10
- from vector_store import PineconeVectorStore
11
 
12
  logging.basicConfig(level=logging.INFO)
13
 
14
- OPENAI_EMBEDDING_SIZE = 1536
15
- MAX_TOKENS_PER_CHUNK = (
16
- 8192 # The ADA embedder from OpenAI has a maximum of 8192 tokens.
17
- )
18
- MAX_CHUNKS_PER_BATCH = (
19
- 2048 # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
20
- )
21
  MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
22
 
23
 
@@ -29,6 +24,8 @@ def _read_extensions(path):
29
  def main():
30
  parser = argparse.ArgumentParser(description="Batch-embeds a repository")
31
  parser.add_argument("repo_id", help="The ID of the repository to index")
 
 
32
  parser.add_argument(
33
  "--local_dir",
34
  default="repos",
@@ -41,11 +38,12 @@ def main():
41
  help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
42
  )
43
  parser.add_argument(
44
- "--chunks_per_batch", type=int, default=2000, help="Maximum chunks per batch"
45
- )
46
- parser.add_argument(
47
- "--pinecone_index_name", required=True, help="Pinecone index name"
48
  )
 
49
  parser.add_argument(
50
  "--include",
51
  help="Path to a file containing a list of extensions to include. One extension per line.",
@@ -56,22 +54,37 @@ def main():
56
  help="Path to a file containing a list of extensions to exclude. One extension per line.",
57
  )
58
  parser.add_argument(
59
- "--max_embedding_jobs", type=int,
 
60
  help="Maximum number of embedding jobs to run. Specifying this might result in "
61
  "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
62
  )
63
-
 
 
 
 
 
 
 
 
 
64
  args = parser.parse_args()
65
 
66
- # Validate the arguments.
 
 
 
 
 
 
 
 
 
67
  if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
68
- parser.error(
69
- f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}."
70
- )
71
  if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
72
- parser.error(
73
- f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}."
74
- )
75
  if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
76
parser.error(f"The maximum number of tokens per job is {MAX_TOKENS_PER_JOB}.")
77
  if args.include and args.exclude:
@@ -91,9 +104,23 @@ def main():
91
 
92
  logging.info("Issuing embedding jobs...")
93
  chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
94
- embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
 
 
 
 
 
 
 
 
 
95
  embedder.embed_repo(args.chunks_per_batch, args.max_embedding_jobs)
96
 
 
 
 
 
 
97
  logging.info("Waiting for embeddings to be ready...")
98
  while not embedder.embeddings_are_ready():
99
  logging.info("Sleeping for 30 seconds...")
@@ -101,11 +128,7 @@ def main():
101
 
102
  logging.info("Moving embeddings to the vector store...")
103
  # Note to developer: Replace this with your preferred vector store.
104
- vector_store = PineconeVectorStore(
105
- index_name=args.pinecone_index_name,
106
- dimension=OPENAI_EMBEDDING_SIZE,
107
- namespace=repo_manager.repo_id,
108
- )
109
  vector_store.ensure_exists()
110
  vector_store.upsert(embedder.download_embeddings())
111
  logging.info("Done!")
 
5
  import time
6
 
  from chunker import UniversalChunker
+ from embedder import MarqoEmbedder, OpenAIBatchEmbedder
  from repo_manager import RepoManager
+ from vector_store import build_from_args

  logging.basicConfig(level=logging.INFO)

+ MAX_TOKENS_PER_CHUNK = 8192  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
+ MAX_CHUNKS_PER_BATCH = 2048  # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
  MAX_TOKENS_PER_JOB = 3_000_000  # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.


  def main():
      parser = argparse.ArgumentParser(description="Batch-embeds a repository")
      parser.add_argument("repo_id", help="The ID of the repository to index")
+     parser.add_argument("--embedder_type", default="openai", choices=["openai", "marqo"])
+     parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
      parser.add_argument(
          "--local_dir",
          default="repos",

          help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
      )
      parser.add_argument(
+         "--chunks_per_batch",
+         type=int,
+         default=2000,
+         help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
      )
+     parser.add_argument("--index_name", required=True, help="Vector store index name")
      parser.add_argument(
          "--include",
          help="Path to a file containing a list of extensions to include. One extension per line.",

          help="Path to a file containing a list of extensions to exclude. One extension per line.",
      )
      parser.add_argument(
+         "--max_embedding_jobs",
+         type=int,
          help="Maximum number of embedding jobs to run. Specifying this might result in "
          "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
      )
+     parser.add_argument(
+         "--marqo_url",
+         default="http://localhost:8882",
+         help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
+     )
+     parser.add_argument(
+         "--marqo_embedding_model",
+         default="hf/e5-base-v2",
+         help="The embedding model to use for Marqo.",
+     )
      args = parser.parse_args()

+     # Validate embedder and vector store compatibility.
+     if args.embedder_type == "openai" and args.vector_store_type != "pinecone":
+         parser.error("When using the OpenAI embedder, the vector store type must be Pinecone.")
+     if args.embedder_type == "marqo" and args.vector_store_type != "marqo":
+         parser.error("When using the Marqo embedder, the vector store type must also be Marqo.")
+     if args.embedder_type == "marqo" and args.chunks_per_batch > 64:
+         args.chunks_per_batch = 64
+         logging.warning("Marqo enforces a limit of 64 chunks per batch. Setting --chunks_per_batch to 64.")
+
+     # Validate other arguments.
      if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
+         parser.error(f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}.")
      if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
+         parser.error(f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}.")
      if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
          parser.error(f"The maximum number of tokens per job is {MAX_TOKENS_PER_JOB}.")
      if args.include and args.exclude:

      logging.info("Issuing embedding jobs...")
      chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
+
+     if args.embedder_type == "openai":
+         embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
+     elif args.embedder_type == "marqo":
+         embedder = MarqoEmbedder(
+             repo_manager, chunker, index_name=args.index_name, url=args.marqo_url, model=args.marqo_embedding_model
+         )
+     else:
+         raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
+
      embedder.embed_repo(args.chunks_per_batch, args.max_embedding_jobs)

+     if args.vector_store_type == "marqo":
+         # Marqo computes embeddings and stores them in the vector store at once, so we're done.
+         logging.info("Done!")
+         return
+
      logging.info("Waiting for embeddings to be ready...")
      while not embedder.embeddings_are_ready():
          logging.info("Sleeping for 30 seconds...")

      logging.info("Moving embeddings to the vector store...")
      # Note to developer: Replace this with your preferred vector store.
+     vector_store = build_from_args(args)
      vector_store.ensure_exists()
      vector_store.upsert(embedder.download_embeddings())
      logging.info("Done!")
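The compatibility checks added to `main()` above reduce to three rules: the OpenAI embedder requires Pinecone, the Marqo embedder requires Marqo, and Marqo batches are capped at 64 chunks. A standalone sketch of that logic (the helper name `resolve_config` and the free-function form are ours, for illustration only):

```python
MARQO_MAX_CHUNKS_PER_BATCH = 64  # Marqo rejects batches larger than 64 documents.


def resolve_config(embedder_type: str, vector_store_type: str, chunks_per_batch: int) -> int:
    """Validates the embedder/vector-store pairing and returns the effective batch size."""
    if embedder_type == "openai" and vector_store_type != "pinecone":
        raise ValueError("When using the OpenAI embedder, the vector store type must be Pinecone.")
    if embedder_type == "marqo" and vector_store_type != "marqo":
        raise ValueError("When using the Marqo embedder, the vector store type must also be Marqo.")
    if embedder_type == "marqo" and chunks_per_batch > MARQO_MAX_CHUNKS_PER_BATCH:
        # index.py logs a warning and clamps the batch size rather than erroring out.
        return MARQO_MAX_CHUNKS_PER_BATCH
    return chunks_per_batch
```

In the script itself the same checks run through `parser.error`, which exits with a usage message instead of raising.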
src/repo_manager.py CHANGED
@@ -35,9 +35,7 @@ class RepoManager:
      @cached_property
      def is_public(self) -> bool:
          """Checks whether a GitHub repository is publicly visible."""
-         response = requests.get(
-             f"https://api.github.com/repos/{self.repo_id}", timeout=10
-         )
+         response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
          # Note that the response will be 404 for both private and non-existent repos.
          return response.status_code == 200

@@ -50,17 +48,13 @@ class RepoManager:
          if self.access_token:
              headers["Authorization"] = f"token {self.access_token}"

-         response = requests.get(
-             f"https://api.github.com/repos/{self.repo_id}", headers=headers
-         )
+         response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers)
          if response.status_code == 200:
              branch = response.json().get("default_branch", "main")
          else:
              # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
              # most common naming for the default branch ("main").
-             logging.warn(
-                 f"Unable to fetch default branch for {self.repo_id}: {response.text}"
-             )
+             logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
              branch = "main"
          return branch

@@ -81,9 +75,7 @@ class RepoManager:
          try:
              Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
          except GitCommandError as e:
-             logging.error(
-                 "Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e
-             )
+             logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
              return False
          return True

@@ -130,9 +122,7 @@ class RepoManager:
          for path in included_file_paths:
              f.write(path + "\n")

-         excluded_file_paths = set(file_paths).difference(
-             set(included_file_paths)
-         )
+         excluded_file_paths = set(file_paths).difference(set(included_file_paths))
          with open(excluded_log_file, "a") as f:
              for path in excluded_file_paths:
                  f.write(path + "\n")

@@ -142,15 +132,11 @@ class RepoManager:
          try:
              contents = f.read()
          except UnicodeDecodeError:
-             logging.warning(
-                 "Unable to decode file %s. Skipping.", file_path
-             )
+             logging.warning("Unable to decode file %s. Skipping.", file_path)
              continue
          yield file_path[len(self.local_dir) + 1 :], contents

      def github_link_for_file(self, file_path: str) -> str:
          """Converts a repository file path to a GitHub link."""
          file_path = file_path[len(self.repo_id) :]
-         return (
-             f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
-         )
+         return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
 
src/vector_store.py CHANGED
@@ -3,13 +3,19 @@
  from abc import ABC, abstractmethod
  from typing import Dict, Generator, List, Tuple

+ import marqo
+ from langchain_community.vectorstores import Marqo
+ from langchain_core.documents import Document
+ from langchain_openai import OpenAIEmbeddings
  from pinecone import Pinecone

+ OPENAI_EMBEDDING_SIZE = 1536
  Vector = Tuple[Dict, List[float]]  # (metadata, embedding)


  class VectorStore(ABC):
      """Abstract class for a vector store."""
+
      @abstractmethod
      def ensure_exists(self):
          """Ensures that the vector store exists. Creates it if it doesn't."""

@@ -29,11 +35,15 @@ class VectorStore(ABC):
          if batch:
              self.upsert_batch(batch)

+     @abstractmethod
+     def to_langchain(self):
+         """Converts the vector store to a LangChain vector store object."""
+

  class PineconeVectorStore(VectorStore):
      """Vector store implementation using Pinecone."""

-     def __init__(self, index_name: str, dimension: int, namespace: str):
+     def __init__(self, index_name: str, namespace: str, dimension: int = OPENAI_EMBEDDING_SIZE):
          self.index_name = index_name
          self.dimension = dimension
          self.client = Pinecone()

@@ -42,13 +52,56 @@ class PineconeVectorStore(VectorStore):
      def ensure_exists(self):
          if self.index_name not in self.client.list_indexes().names():
-             self.client.create_index(
-                 name=self.index_name, dimension=self.dimension, metric="cosine"
-             )
+             self.client.create_index(name=self.index_name, dimension=self.dimension, metric="cosine")

      def upsert_batch(self, vectors: List[Vector]):
          pinecone_vectors = [
-             (metadata.get("id", str(i)), embedding, metadata)
-             for i, (metadata, embedding) in enumerate(vectors)
+             (metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
          ]
          self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
+
+     def to_langchain(self):
+         return Pinecone.from_existing_index(
+             index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
+         )
+
+
+ class MarqoVectorStore(VectorStore):
+     """Vector store implementation using Marqo."""
+
+     def __init__(self, url: str, index_name: str):
+         self.client = marqo.Client(url=url)
+         self.index_name = index_name
+
+     def ensure_exists(self):
+         pass
+
+     def upsert_batch(self, vectors: List[Vector]):
+         # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
+         pass
+
+     def to_langchain(self):
+         vectorstore = Marqo(client=self.client, index_name=self.index_name)
+
+         # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata"
+         # field in the result, and instead take the "filename" directly from the result.
+         def patched_method(self, results):
+             documents: List[Document] = []
+             for res in results["hits"]:
+                 documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
+             return documents
+
+         vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
+             vectorstore, vectorstore.__class__
+         )
+         return vectorstore
+
+
+ def build_from_args(args) -> VectorStore:
+     """Builds a vector store from the given command-line arguments."""
+     if args.vector_store_type == "pinecone":
+         return PineconeVectorStore(index_name=args.index_name, namespace=args.repo_id)
+     elif args.vector_store_type == "marqo":
+         return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
+     else:
+         raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
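The `patched_method.__get__(vectorstore, vectorstore.__class__)` call in `to_langchain` above uses Python's descriptor protocol to bind a plain function as a method on one instance, leaving the class and all other instances untouched. A minimal self-contained illustration (the `Store` class is invented for the demo):

```python
class Store:
    def describe(self) -> str:
        return "original"


def patched(self) -> str:
    # `self` is supplied by the bound-method machinery, exactly as in the diff above.
    return f"patched {type(self).__name__}"


store = Store()
other = Store()

# function.__get__(instance, cls) returns a bound method; assigning it to the
# instance shadows the class-level method for that instance only.
store.describe = patched.__get__(store, Store)

assert store.describe() == "patched Store"  # only this instance is affected
assert other.describe() == "original"
```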