juliaturc commited on
Commit
57007fe
·
1 Parent(s): e8553c3

Support marqo on the inference side and format code.

Browse files
Files changed (6) hide show
  1. src/chat.py +38 -18
  2. src/chunker.py +9 -22
  3. src/embedder.py +7 -15
  4. src/index.py +18 -23
  5. src/repo_manager.py +7 -21
  6. src/vector_store.py +3 -5
src/chat.py CHANGED
@@ -4,14 +4,17 @@ You must run main.py first in order to index the codebase into a vector store.
4
  """
5
 
6
  import argparse
7
-
8
- from dotenv import load_dotenv
9
 
10
  import gradio as gr
11
- from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 
 
 
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
13
  from langchain.schema import AIMessage, HumanMessage
14
- from langchain_community.vectorstores import Pinecone
 
15
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
16
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
17
 
@@ -24,10 +27,29 @@ def build_rag_chain(args):
24
  """Builds a RAG chain via LangChain."""
25
  llm = ChatOpenAI(model=args.openai_model)
26
 
27
- vectorstore = Pinecone.from_existing_index(
28
- index_name=args.pinecone_index_name,
29
- embedding=OpenAIEmbeddings(),
30
- namespace=args.repo_id,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  )
32
 
33
  retriever = vectorstore.as_retriever()
@@ -45,9 +67,7 @@ def build_rag_chain(args):
45
  ("human", "{input}"),
46
  ]
47
  )
48
- history_aware_retriever = create_history_aware_retriever(
49
- llm, retriever, contextualize_q_prompt
50
- )
51
 
52
  qa_system_prompt = (
53
  f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
@@ -76,9 +96,7 @@ def append_sources_to_response(response):
76
  # Deduplicate filenames while preserving their order.
77
  filenames = list(dict.fromkeys(filenames))
78
  repo_manager = RepoManager(args.repo_id)
79
- github_links = [
80
- repo_manager.github_link_for_file(filename) for filename in filenames
81
- ]
82
  return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
83
 
84
 
@@ -90,8 +108,12 @@ if __name__ == "__main__":
90
  default="gpt-4",
91
  help="The OpenAI model to use for response generation",
92
  )
 
 
93
  parser.add_argument(
94
- "--pinecone_index_name", required=True, help="Pinecone index name"
 
 
95
  )
96
  parser.add_argument(
97
  "--share",
@@ -109,9 +131,7 @@ if __name__ == "__main__":
109
  history_langchain_format.append(HumanMessage(content=human))
110
  history_langchain_format.append(AIMessage(content=ai))
111
  history_langchain_format.append(HumanMessage(content=message))
112
- response = rag_chain.invoke(
113
- {"input": message, "chat_history": history_langchain_format}
114
- )
115
  answer = append_sources_to_response(response)
116
  return answer
117
 
 
4
  """
5
 
6
  import argparse
7
+ from typing import List
 
8
 
9
  import gradio as gr
10
+ import marqo
11
+ from dotenv import load_dotenv
12
+ from langchain.chains import (create_history_aware_retriever,
13
+ create_retrieval_chain)
14
  from langchain.chains.combine_documents import create_stuff_documents_chain
15
  from langchain.schema import AIMessage, HumanMessage
16
+ from langchain_community.vectorstores import Marqo, Pinecone
17
+ from langchain_core.documents import Document
18
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
19
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
20
 
 
27
  """Builds a RAG chain via LangChain."""
28
  llm = ChatOpenAI(model=args.openai_model)
29
 
30
+ if args.vector_store_type == "pinecone":
31
+ vectorstore = Pinecone.from_existing_index(
32
+ index_name=args.pinecone_index_name,
33
+ embedding=OpenAIEmbeddings(),
34
+ namespace=args.repo_id,
35
+ )
36
+ elif args.vector_store_type == "marqo":
37
+ marqo_client = marqo.Client(url=args.marqo_url)
38
+ vectorstore = Marqo(
39
+ client=marqo_client,
40
+ index_name=args.index_name,
41
+ )
42
+
43
+ # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in the
44
+ # result, and instead take the "filename" directly from the result.
45
+ def patched_method(self, results):
46
+ documents: List[Document] = []
47
+ for res in results["hits"]:
48
+ documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
49
+ return documents
50
+
51
+ vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
52
+ vectorstore, vectorstore.__class__
53
  )
54
 
55
  retriever = vectorstore.as_retriever()
 
67
  ("human", "{input}"),
68
  ]
69
  )
70
+ history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
 
 
71
 
72
  qa_system_prompt = (
73
  f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
 
96
  # Deduplicate filenames while preserving their order.
97
  filenames = list(dict.fromkeys(filenames))
98
  repo_manager = RepoManager(args.repo_id)
99
+ github_links = [repo_manager.github_link_for_file(filename) for filename in filenames]
 
 
100
  return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
101
 
102
 
 
108
  default="gpt-4",
109
  help="The OpenAI model to use for response generation",
110
  )
111
+ parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
112
+ parser.add_argument("--index_name", required=True, help="Vector store index name")
113
  parser.add_argument(
114
+ "--marqo_url",
115
+ default="http://localhost:8882",
116
+ help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
117
  )
118
  parser.add_argument(
119
  "--share",
 
131
  history_langchain_format.append(HumanMessage(content=human))
132
  history_langchain_format.append(AIMessage(content=ai))
133
  history_langchain_format.append(HumanMessage(content=message))
134
+ response = rag_chain.invoke({"input": message, "chat_history": history_langchain_format})
 
 
135
  answer = append_sources_to_response(response)
136
  return answer
137
 
src/chunker.py CHANGED
@@ -1,12 +1,12 @@
1
  """Chunker abstraction and implementations."""
2
 
3
  import logging
4
- import nbformat
5
  from abc import ABC, abstractmethod
6
  from dataclasses import dataclass
7
  from functools import lru_cache
8
  from typing import List, Optional
9
 
 
10
  import pygments
11
  import tiktoken
12
  from semchunk import chunk as chunk_via_semchunk
@@ -31,7 +31,7 @@ class Chunk:
31
  return self._content
32
 
33
  @property
34
- def to_dict(self):
35
  """Converts the chunk to a dictionary that can be passed to a vector store."""
36
  # Some vector stores require the IDs to be ASCII.
37
  filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
@@ -49,9 +49,7 @@ class Chunk:
49
 
50
  def populate_content(self, file_content: str):
51
  """Populates the content of the chunk with the file path and file content."""
52
- self._content = (
53
- self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
54
- )
55
 
56
  def num_tokens(self, tokenizer):
57
  """Counts the number of tokens in the chunk."""
@@ -115,9 +113,7 @@ class CodeChunker(Chunker):
115
 
116
  if not node.children:
117
  # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
118
- return self.text_chunker.chunk(
119
- filename, file_content[node.start_byte : node.end_byte]
120
- )
121
 
122
  chunks = []
123
  for child in node.children:
@@ -133,11 +129,7 @@ class CodeChunker(Chunker):
133
  for chunk in chunks:
134
  if not merged_chunks:
135
  merged_chunks.append(chunk)
136
- elif (
137
- merged_chunks[-1].num_tokens(self.tokenizer)
138
- + chunk.num_tokens(self.tokenizer)
139
- < self.max_tokens - 50
140
- ):
141
  # There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
142
  # at this point, because tokenization is not necessarily additive.
143
  merged = Chunk(
@@ -203,9 +195,7 @@ class CodeChunker(Chunker):
203
  # a bug in the code.
204
  assert chunk.content
205
  size = chunk.num_tokens(self.tokenizer)
206
- assert (
207
- size <= self.max_tokens
208
- ), f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
209
 
210
  return chunks
211
 
@@ -217,17 +207,13 @@ class TextChunker(Chunker):
217
  self.max_tokens = max_tokens
218
 
219
  tokenizer = tiktoken.get_encoding("cl100k_base")
220
- self.count_tokens = lambda text: len(
221
- tokenizer.encode(text, disallowed_special=())
222
- )
223
 
224
  def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
225
  """Chunks a text file into smaller pieces."""
226
  # We need to allocate some tokens for the filename, which is part of the chunk content.
227
  extra_tokens = self.count_tokens(file_path + "\n\n")
228
- text_chunks = chunk_via_semchunk(
229
- file_content, self.max_tokens - extra_tokens, self.count_tokens
230
- )
231
 
232
  chunks = []
233
  start = 0
@@ -252,6 +238,7 @@ class IPYNBChunker(Chunker):
252
 
253
  Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
254
  """
 
255
  def __init__(self, code_chunker: CodeChunker):
256
  self.code_chunker = code_chunker
257
 
 
1
  """Chunker abstraction and implementations."""
2
 
3
  import logging
 
4
  from abc import ABC, abstractmethod
5
  from dataclasses import dataclass
6
  from functools import lru_cache
7
  from typing import List, Optional
8
 
9
+ import nbformat
10
  import pygments
11
  import tiktoken
12
  from semchunk import chunk as chunk_via_semchunk
 
31
  return self._content
32
 
33
  @property
34
+ def to_metadata(self):
35
  """Converts the chunk to a dictionary that can be passed to a vector store."""
36
  # Some vector stores require the IDs to be ASCII.
37
  filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
 
49
 
50
  def populate_content(self, file_content: str):
51
  """Populates the content of the chunk with the file path and file content."""
52
+ self._content = self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
 
 
53
 
54
  def num_tokens(self, tokenizer):
55
  """Counts the number of tokens in the chunk."""
 
113
 
114
  if not node.children:
115
  # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
116
+ return self.text_chunker.chunk(filename, file_content[node.start_byte : node.end_byte])
 
 
117
 
118
  chunks = []
119
  for child in node.children:
 
129
  for chunk in chunks:
130
  if not merged_chunks:
131
  merged_chunks.append(chunk)
132
+ elif merged_chunks[-1].num_tokens(self.tokenizer) + chunk.num_tokens(self.tokenizer) < self.max_tokens - 50:
 
 
 
 
133
  # There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
134
  # at this point, because tokenization is not necessarily additive.
135
  merged = Chunk(
 
195
  # a bug in the code.
196
  assert chunk.content
197
  size = chunk.num_tokens(self.tokenizer)
198
+ assert size <= self.max_tokens, f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
 
 
199
 
200
  return chunks
201
 
 
207
  self.max_tokens = max_tokens
208
 
209
  tokenizer = tiktoken.get_encoding("cl100k_base")
210
+ self.count_tokens = lambda text: len(tokenizer.encode(text, disallowed_special=()))
 
 
211
 
212
  def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
213
  """Chunks a text file into smaller pieces."""
214
  # We need to allocate some tokens for the filename, which is part of the chunk content.
215
  extra_tokens = self.count_tokens(file_path + "\n\n")
216
+ text_chunks = chunk_via_semchunk(file_content, self.max_tokens - extra_tokens, self.count_tokens)
 
 
217
 
218
  chunks = []
219
  start = 0
 
238
 
239
  Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
240
  """
241
+
242
  def __init__(self, code_chunker: CodeChunker):
243
  self.code_chunker = code_chunker
244
 
src/embedder.py CHANGED
@@ -7,11 +7,11 @@ from abc import ABC, abstractmethod
7
  from collections import Counter
8
  from typing import Dict, Generator, List, Tuple
9
 
 
10
  from openai import OpenAI
11
 
12
  from chunker import Chunk, Chunker
13
  from repo_manager import RepoManager
14
- import marqo
15
 
16
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
17
 
@@ -63,7 +63,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
63
  openai_batch_id = self._issue_job_for_chunks(
64
  sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
65
  )
66
- self.openai_batch_ids[openai_batch_id] = [chunk.to_dict for chunk in sub_batch]
67
  if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
68
  logging.info("Reached the maximum number of embedding jobs. Stopping.")
69
  return
@@ -72,7 +72,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
72
  # Finally, commit the last batch.
73
  if batch:
74
  openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
75
- self.openai_batch_ids[openai_batch_id] = [chunk.to_dict for chunk in batch]
76
  logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
77
 
78
  # Save the job IDs to a file, just in case this script is terminated by mistake.
@@ -179,12 +179,7 @@ class MarqoEmbedder(BatchEmbedder):
179
  Embeddings can be stored locally (in which case the `url` passed to the constructor should point to localhost) or in the cloud.
180
  """
181
 
182
- def __init__(self,
183
- repo_manager: RepoManager,
184
- chunker: Chunker,
185
- index_name: str,
186
- url: str,
187
- model="hf/e5-base-v2"):
188
  self.repo_manager = repo_manager
189
  self.chunker = chunker
190
  self.client = marqo.Client(url=url)
@@ -212,8 +207,8 @@ class MarqoEmbedder(BatchEmbedder):
212
  sub_batch = batch[i : i + chunks_per_batch]
213
  logging.info("Indexing %d chunks...", len(sub_batch))
214
  self.index.add_documents(
215
- documents=[chunk.to_dict for chunk in sub_batch],
216
- tensor_fields=["text"]
217
  )
218
 
219
  if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
@@ -223,10 +218,7 @@ class MarqoEmbedder(BatchEmbedder):
223
 
224
  # Finally, commit the last batch.
225
  if batch:
226
- self.index.add_documents(
227
- documents=[chunk.to_dict for chunk in batch],
228
- tensor_fields=["text"]
229
- )
230
  logging.info(f"Successfully embedded {chunk_count} chunks.")
231
 
232
  def embeddings_are_ready(self) -> bool:
 
7
  from collections import Counter
8
  from typing import Dict, Generator, List, Tuple
9
 
10
+ import marqo
11
  from openai import OpenAI
12
 
13
  from chunker import Chunk, Chunker
14
  from repo_manager import RepoManager
 
15
 
16
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
17
 
 
63
  openai_batch_id = self._issue_job_for_chunks(
64
  sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
65
  )
66
+ self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in sub_batch]
67
  if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
68
  logging.info("Reached the maximum number of embedding jobs. Stopping.")
69
  return
 
72
  # Finally, commit the last batch.
73
  if batch:
74
  openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
75
+ self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in batch]
76
  logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
77
 
78
  # Save the job IDs to a file, just in case this script is terminated by mistake.
 
179
  Embeddings can be stored locally (in which case the `url` passed to the constructor should point to localhost) or in the cloud.
180
  """
181
 
182
+ def __init__(self, repo_manager: RepoManager, chunker: Chunker, index_name: str, url: str, model="hf/e5-base-v2"):
 
 
 
 
 
183
  self.repo_manager = repo_manager
184
  self.chunker = chunker
185
  self.client = marqo.Client(url=url)
 
207
  sub_batch = batch[i : i + chunks_per_batch]
208
  logging.info("Indexing %d chunks...", len(sub_batch))
209
  self.index.add_documents(
210
+ documents=[chunk.to_metadata for chunk in sub_batch],
211
+ tensor_fields=["text"],
212
  )
213
 
214
  if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
 
218
 
219
  # Finally, commit the last batch.
220
  if batch:
221
+ self.index.add_documents(documents=[chunk.to_metadata for chunk in batch], tensor_fields=["text"])
 
 
 
222
  logging.info(f"Successfully embedded {chunk_count} chunks.")
223
 
224
  def embeddings_are_ready(self) -> bool:
src/index.py CHANGED
@@ -5,19 +5,15 @@ import logging
5
  import time
6
 
7
  from chunker import UniversalChunker
8
- from embedder import OpenAIBatchEmbedder, MarqoEmbedder
9
  from repo_manager import RepoManager
10
  from vector_store import PineconeVectorStore
11
 
12
  logging.basicConfig(level=logging.INFO)
13
 
14
  OPENAI_EMBEDDING_SIZE = 1536
15
- MAX_TOKENS_PER_CHUNK = (
16
- 8192 # The ADA embedder from OpenAI has a maximum of 8192 tokens.
17
- )
18
- MAX_CHUNKS_PER_BATCH = (
19
- 2048 # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
20
- )
21
  MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
22
 
23
 
@@ -43,11 +39,12 @@ def main():
43
  help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
44
  )
45
  parser.add_argument(
46
- "--chunks_per_batch", type=int, default=2000, help="Maximum chunks per batch"
47
- )
48
- parser.add_argument(
49
- "--index_name", required=True, help="Vector store index name"
50
  )
 
51
  parser.add_argument(
52
  "--include",
53
  help="Path to a file containing a list of extensions to include. One extension per line.",
@@ -58,7 +55,8 @@ def main():
58
  help="Path to a file containing a list of extensions to exclude. One extension per line.",
59
  )
60
  parser.add_argument(
61
- "--max_embedding_jobs", type=int,
 
62
  help="Maximum number of embedding jobs to run. Specifying this might result in "
63
  "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
64
  )
@@ -79,16 +77,15 @@ def main():
79
  parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
80
  if args.embedder_type == "marqo" and args.vector_store_type != "marqo":
81
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
 
 
 
82
 
83
  # Validate other arguments.
84
  if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
85
- parser.error(
86
- f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}."
87
- )
88
  if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
89
- parser.error(
90
- f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}."
91
- )
92
  if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
93
  parser.error(f"The maximum number of chunks per job is {MAX_TOKENS_PER_JOB}.")
94
  if args.include and args.exclude:
@@ -112,11 +109,9 @@ def main():
112
  if args.embedder_type == "openai":
113
  embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
114
  elif args.embedder_type == "marqo":
115
- embedder = MarqoEmbedder(repo_manager,
116
- chunker,
117
- index_name=args.index_name,
118
- url=args.marqo_url,
119
- model=args.marqo_embedding_model)
120
  else:
121
  raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
122
 
 
5
  import time
6
 
7
  from chunker import UniversalChunker
8
+ from embedder import MarqoEmbedder, OpenAIBatchEmbedder
9
  from repo_manager import RepoManager
10
  from vector_store import PineconeVectorStore
11
 
12
  logging.basicConfig(level=logging.INFO)
13
 
14
  OPENAI_EMBEDDING_SIZE = 1536
15
+ MAX_TOKENS_PER_CHUNK = 8192 # The ADA embedder from OpenAI has a maximum of 8192 tokens.
16
+ MAX_CHUNKS_PER_BATCH = 2048 # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
 
 
 
 
17
  MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
18
 
19
 
 
39
  help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
40
  )
41
  parser.add_argument(
42
+ "--chunks_per_batch",
43
+ type=int,
44
+ default=2000,
45
+ help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
46
  )
47
+ parser.add_argument("--index_name", required=True, help="Vector store index name")
48
  parser.add_argument(
49
  "--include",
50
  help="Path to a file containing a list of extensions to include. One extension per line.",
 
55
  help="Path to a file containing a list of extensions to exclude. One extension per line.",
56
  )
57
  parser.add_argument(
58
+ "--max_embedding_jobs",
59
+ type=int,
60
  help="Maximum number of embedding jobs to run. Specifying this might result in "
61
  "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
62
  )
 
77
  parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
78
  if args.embedder_type == "marqo" and args.vector_store_type != "marqo":
79
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
80
+ if args.embedder_type == "marqo" and args.chunks_per_batch > 64:
81
+ args.chunks_per_batch = 64
82
+ logging.warning("Marqo enforces a limit of 64 chunks per batch. Setting --chunks_per_batch to 64.")
83
 
84
  # Validate other arguments.
85
  if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
86
+ parser.error(f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}.")
 
 
87
  if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
88
+ parser.error(f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}.")
 
 
89
  if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
90
  parser.error(f"The maximum number of chunks per job is {MAX_TOKENS_PER_JOB}.")
91
  if args.include and args.exclude:
 
109
  if args.embedder_type == "openai":
110
  embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
111
  elif args.embedder_type == "marqo":
112
+ embedder = MarqoEmbedder(
113
+ repo_manager, chunker, index_name=args.index_name, url=args.marqo_url, model=args.marqo_embedding_model
114
+ )
 
 
115
  else:
116
  raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
117
 
src/repo_manager.py CHANGED
@@ -35,9 +35,7 @@ class RepoManager:
35
  @cached_property
36
  def is_public(self) -> bool:
37
  """Checks whether a GitHub repository is publicly visible."""
38
- response = requests.get(
39
- f"https://api.github.com/repos/{self.repo_id}", timeout=10
40
- )
41
  # Note that the response will be 404 for both private and non-existent repos.
42
  return response.status_code == 200
43
 
@@ -50,17 +48,13 @@ class RepoManager:
50
  if self.access_token:
51
  headers["Authorization"] = f"token {self.access_token}"
52
 
53
- response = requests.get(
54
- f"https://api.github.com/repos/{self.repo_id}", headers=headers
55
- )
56
  if response.status_code == 200:
57
  branch = response.json().get("default_branch", "main")
58
  else:
59
  # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
60
  # most common naming for the default branch ("main").
61
- logging.warn(
62
- f"Unable to fetch default branch for {self.repo_id}: {response.text}"
63
- )
64
  branch = "main"
65
  return branch
66
 
@@ -81,9 +75,7 @@ class RepoManager:
81
  try:
82
  Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
83
  except GitCommandError as e:
84
- logging.error(
85
- "Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e
86
- )
87
  return False
88
  return True
89
 
@@ -130,9 +122,7 @@ class RepoManager:
130
  for path in included_file_paths:
131
  f.write(path + "\n")
132
 
133
- excluded_file_paths = set(file_paths).difference(
134
- set(included_file_paths)
135
- )
136
  with open(excluded_log_file, "a") as f:
137
  for path in excluded_file_paths:
138
  f.write(path + "\n")
@@ -142,15 +132,11 @@ class RepoManager:
142
  try:
143
  contents = f.read()
144
  except UnicodeDecodeError:
145
- logging.warning(
146
- "Unable to decode file %s. Skipping.", file_path
147
- )
148
  continue
149
  yield file_path[len(self.local_dir) + 1 :], contents
150
 
151
  def github_link_for_file(self, file_path: str) -> str:
152
  """Converts a repository file path to a GitHub link."""
153
  file_path = file_path[len(self.repo_id) :]
154
- return (
155
- f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
156
- )
 
35
  @cached_property
36
  def is_public(self) -> bool:
37
  """Checks whether a GitHub repository is publicly visible."""
38
+ response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
 
 
39
  # Note that the response will be 404 for both private and non-existent repos.
40
  return response.status_code == 200
41
 
 
48
  if self.access_token:
49
  headers["Authorization"] = f"token {self.access_token}"
50
 
51
+ response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers)
 
 
52
  if response.status_code == 200:
53
  branch = response.json().get("default_branch", "main")
54
  else:
55
  # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
56
  # most common naming for the default branch ("main").
57
+ logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
 
 
58
  branch = "main"
59
  return branch
60
 
 
75
  try:
76
  Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
77
  except GitCommandError as e:
78
+ logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
 
 
79
  return False
80
  return True
81
 
 
122
  for path in included_file_paths:
123
  f.write(path + "\n")
124
 
125
+ excluded_file_paths = set(file_paths).difference(set(included_file_paths))
 
 
126
  with open(excluded_log_file, "a") as f:
127
  for path in excluded_file_paths:
128
  f.write(path + "\n")
 
132
  try:
133
  contents = f.read()
134
  except UnicodeDecodeError:
135
+ logging.warning("Unable to decode file %s. Skipping.", file_path)
 
 
136
  continue
137
  yield file_path[len(self.local_dir) + 1 :], contents
138
 
139
  def github_link_for_file(self, file_path: str) -> str:
140
  """Converts a repository file path to a GitHub link."""
141
  file_path = file_path[len(self.repo_id) :]
142
+ return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
 
 
src/vector_store.py CHANGED
@@ -10,6 +10,7 @@ Vector = Tuple[Dict, List[float]] # (metadata, embedding)
10
 
11
  class VectorStore(ABC):
12
  """Abstract class for a vector store."""
 
13
  @abstractmethod
14
  def ensure_exists(self):
15
  """Ensures that the vector store exists. Creates it if it doesn't."""
@@ -42,13 +43,10 @@ class PineconeVectorStore(VectorStore):
42
 
43
  def ensure_exists(self):
44
  if self.index_name not in self.client.list_indexes().names():
45
- self.client.create_index(
46
- name=self.index_name, dimension=self.dimension, metric="cosine"
47
- )
48
 
49
  def upsert_batch(self, vectors: List[Vector]):
50
  pinecone_vectors = [
51
- (metadata.get("id", str(i)), embedding, metadata)
52
- for i, (metadata, embedding) in enumerate(vectors)
53
  ]
54
  self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
 
10
 
11
  class VectorStore(ABC):
12
  """Abstract class for a vector store."""
13
+
14
  @abstractmethod
15
  def ensure_exists(self):
16
  """Ensures that the vector store exists. Creates it if it doesn't."""
 
43
 
44
  def ensure_exists(self):
45
  if self.index_name not in self.client.list_indexes().names():
46
+ self.client.create_index(name=self.index_name, dimension=self.dimension, metric="cosine")
 
 
47
 
48
  def upsert_batch(self, vectors: List[Vector]):
49
  pinecone_vectors = [
50
+ (metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
 
51
  ]
52
  self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)