Index GitHub Issues (#21)
* Generalize RepoManager into DataManager
* Add chunker for GitHub issues
* Update README, fix flags.
- README.md +21 -14
- src/chat.py +10 -14
- src/chunker.py +93 -72
- src/{repo_manager.py → data_manager.py} +49 -25
- src/embedder.py +50 -34
- src/github.py +226 -0
- src/index.py +91 -52
README.md
CHANGED

````diff
@@ -38,9 +38,9 @@ We currently support two options for indexing the codebase:
 
 python src/index.py \
     github-repo-name \  # e.g. Storia-AI/repo2vec
-    --
-    --
-    --
+    --embedder-type=marqo \
+    --vector-store-type=marqo \
+    --index-name=your-index-name
 ```
 
 2. **Using external providers** (OpenAI for embeddings and [Pinecone](https://www.pinecone.io/) for the vector store). To index your codebase, run:
@@ -52,12 +52,15 @@ We currently support two options for indexing the codebase:
 
 python src/index.py \
     github-repo-name \  # e.g. Storia-AI/repo2vec
-    --
-    --
-    --
+    --embedder-type=openai \
+    --vector-store-type=pinecone \
+    --index-name=your-index-name
 ```
 We are planning on adding more providers soon, so that you can mix and match them. Contributions are also welcome!
 
+## Indexing GitHub Issues
+By default, we also index the open GitHub issues associated with a codebase. You can control what gets indexed with the `--index-repo` and `--index-issues` flags (and their converse `--no-index-repo` and `--no-index-issues`).
+
 ## Chatting with the codebase
 We provide a `gradio` app where you can chat with your codebase. You can use either a local LLM (via [Ollama](https://ollama.com)), or a cloud provider like OpenAI or Anthropic.
 
@@ -68,10 +71,10 @@ To chat with a local LLM:
 ```
 python src/chat.py \
     github-repo-name \  # e.g. Storia-AI/repo2vec
-    --
-    --
-    --
-    --
+    --llm-provider=ollama \
+    --llm-model=llama3.1 \
+    --vector-store-type=marqo \  # or pinecone
+    --index-name=your-index-name
 ```
 
 To chat with a cloud-based LLM, for instance Anthropic's Claude:
@@ -80,10 +83,10 @@ export ANTHROPIC_API_KEY=...
 
 python src/chat.py \
     github-repo-name \  # e.g. Storia-AI/repo2vec
-    --
-    --
-    --
-    --
+    --llm-provider=anthropic \
+    --llm-model=claude-3-opus-20240229 \
+    --vector-store-type=marqo \  # or pinecone
+    --index-name=your-index-name
 ```
 To get a public URL for your chat app, set `--share=true`.
 
@@ -121,6 +124,10 @@ The `src/chat.py` brings up a [Gradio app](https://www.gradio.app/) with a chat
 
 The sources are conveniently surfaced in the chat and linked directly to GitHub.
 
+# Changelog
+- 2024-09-03: Support for indexing GitHub issues.
+- 2024-08-30: Support for running everything locally (Marqo for embeddings, Ollama for LLMs).
+
 # Want your repository hosted?
 
 We're working to make all code on the internet searchable and understandable for devs. You can check out our early product, [Code Sage](https://sage.storia.ai). We pre-indexed a slew of OSS repos, and you can index your desired ones by simply pasting a GitHub URL.
````
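The `--index-repo`/`--no-index-repo` and `--index-issues`/`--no-index-issues` flag pairs described in the README can be expressed in `argparse` with `BooleanOptionalAction`, which auto-generates the `--no-*` converse for each flag. This is a minimal sketch of that pattern; the diff does not show how the repo actually defines these flags, so the parser setup below is an assumption:

```python
import argparse

# Sketch (not the repo's actual code) of paired boolean flags like
# --index-issues / --no-index-issues. BooleanOptionalAction requires Python 3.9+.
parser = argparse.ArgumentParser()
parser.add_argument("--index-repo", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument("--index-issues", default=True, action=argparse.BooleanOptionalAction)

args = parser.parse_args(["--no-index-issues"])
print(args.index_repo, args.index_issues)  # True False
```

Both flags default to on, matching the README's "by default, we also index the open GitHub issues" behavior.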
src/chat.py
CHANGED

```diff
@@ -7,15 +7,13 @@ import argparse
 
 import gradio as gr
 from dotenv import load_dotenv
-from langchain.chains import
-    create_retrieval_chain)
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.schema import AIMessage, HumanMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 
 import vector_store
 from llm import build_llm_via_langchain
-from repo_manager import RepoManager
 
 load_dotenv()
 
@@ -63,26 +61,24 @@ def build_rag_chain(args):
 
 def append_sources_to_response(response):
     """Given an OpenAI completion response, appends to it GitHub links of the context sources."""
-
-    # Deduplicate
-
-
-    github_links = [repo_manager.github_link_for_file(filename) for filename in filenames]
-    return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
+    urls = [document.metadata["url"] for document in response["context"]]
+    # Deduplicate urls while preserving their order.
+    urls = list(dict.fromkeys(urls))
+    return response["answer"] + "\n\nSources:\n" + "\n".join(urls)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="UI to chat with your codebase")
     parser.add_argument("repo_id", help="The ID of the repository to index")
-    parser.add_argument("--
+    parser.add_argument("--llm-provider", default="anthropic", choices=["openai", "anthropic", "ollama"])
     parser.add_argument(
-        "--
+        "--llm-model",
         help="The LLM name. Must be supported by the provider specified via --llm_provider.",
     )
-    parser.add_argument("--
-    parser.add_argument("--
+    parser.add_argument("--vector-store-type", default="pinecone", choices=["pinecone", "marqo"])
+    parser.add_argument("--index-name", required=True, help="Vector store index name")
     parser.add_argument(
-        "--
+        "--marqo-url",
         default="http://localhost:8882",
         help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
     )
```
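The new `append_sources_to_response` deduplicates source URLs with `dict.fromkeys` rather than `set`, which keeps the first-seen order of the retrieved documents. A standalone illustration of that idiom (the URLs are placeholders):

```python
urls = [
    "https://github.com/Storia-AI/repo2vec/blob/main/src/chat.py",
    "https://github.com/Storia-AI/repo2vec/blob/main/src/index.py",
    "https://github.com/Storia-AI/repo2vec/blob/main/src/chat.py",  # duplicate
]

# dict keys are unique and (since Python 3.7) preserve insertion order,
# so this drops duplicates while keeping the first occurrence of each URL.
deduped = list(dict.fromkeys(urls))
print(len(deduped))  # 2
```

A plain `set(urls)` would also deduplicate, but scrambles the order in which sources are shown to the user.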
src/chunker.py
CHANGED

```diff
@@ -3,8 +3,8 @@
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from functools import
-from typing import List, Optional
+from functools import cached_property
+from typing import Any, Dict, List, Optional
 
 import nbformat
 import pygments
@@ -14,31 +14,47 @@ from tree_sitter import Node
 from tree_sitter_language_pack import get_parser
 
 logger = logging.getLogger(__name__)
+tokenizer = tiktoken.get_encoding("cl100k_base")
 
 
-@dataclass
 class Chunk:
+    @abstractmethod
+    def content(self) -> str:
+        """The content of the chunk to be indexed."""
+
+    @abstractmethod
+    def metadata(self) -> Dict:
+        """Metadata for the chunk to be indexed."""
+
+
+@dataclass
+class FileChunk(Chunk):
     """A chunk of code or text extracted from a file in the repository."""
 
-
+    file_content: str  # The content of the entire file, not just this chunk.
+    file_metadata: Dict  # Metadata of the entire file, not just this chunk.
     start_byte: int
     end_byte: int
-    _content: Optional[str] = None
 
-    @
+    @cached_property
+    def filename(self):
+        if not "file_path" in self.file_metadata:
+            raise ValueError("file_metadata must contain a 'file_path' key.")
+        return self.file_metadata["file_path"]
+
+    @cached_property
     def content(self) -> Optional[str]:
         """The text content to be embedded. Might contain information beyond just the text snippet from the file."""
-        return self.
+        return self.filename + "\n\n" + self.file_content[self.start_byte : self.end_byte]
 
-    @
-    def
+    @cached_property
+    def metadata(self):
         """Converts the chunk to a dictionary that can be passed to a vector store."""
         # Some vector stores require the IDs to be ASCII.
         filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
-
+        chunk_metadata = {
             # Some vector stores require the IDs to be ASCII.
             "id": f"{filename_ascii}_{self.start_byte}_{self.end_byte}",
-            "filename": self.filename,
             "start_byte": self.start_byte,
             "end_byte": self.end_byte,
             # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
@@ -46,22 +62,13 @@ class Chunk:
             # directly from the repository when needed.
             "text": self.content,
         }
+        chunk_metadata.update(self.file_metadata)
+        return chunk_metadata
 
-
-
-
-    def num_tokens(self, tokenizer):
-        """Counts the number of tokens in the chunk."""
-        if not self.content:
-            raise ValueError("Content not populated.")
-        return Chunk._cached_num_tokens(self.content, tokenizer)
-
-    @staticmethod
-    @lru_cache(maxsize=1024)
-    def _cached_num_tokens(content: str, tokenizer):
-        """Static method to cache token counts."""
-        return len(tokenizer.encode(content, disallowed_special=()))
+    @cached_property
+    def num_tokens(self):
+        """Number of tokens in this chunk."""
+        return len(tokenizer.encode(self.content, disallowed_special=()))
 
     def __eq__(self, other):
         if isinstance(other, Chunk):
@@ -77,20 +84,19 @@ class Chunk:
 
 
 class Chunker(ABC):
-    """Abstract class for chunking a
+    """Abstract class for chunking a datum into smaller pieces."""
 
     @abstractmethod
-    def chunk(self,
-        """Chunks a
+    def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
+        """Chunks a datum into smaller pieces."""
 
 
-class
+class CodeFileChunker(Chunker):
     """Splits a code file into chunks of at most `max_tokens` tokens each."""
 
     def __init__(self, max_tokens: int):
         self.max_tokens = max_tokens
-        self.
-        self.text_chunker = TextChunker(max_tokens)
+        self.text_chunker = TextFileChunker(max_tokens)
 
     @staticmethod
     def _get_language_from_filename(filename: str):
@@ -103,25 +109,24 @@ class CodeChunker(Chunker):
         except pygments.util.ClassNotFound:
             return None
 
-    def _chunk_node(self, node: Node,
+    def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> List[FileChunk]:
         """Splits a node in the parse tree into a flat list of chunks."""
-        node_chunk =
-        node_chunk.populate_content(file_content)
+        node_chunk = FileChunk(file_content, file_metadata, node.start_byte, node.end_byte)
 
-        if node_chunk.num_tokens
+        if node_chunk.num_tokens <= self.max_tokens:
             return [node_chunk]
 
         if not node.children:
             # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
-            return self.text_chunker.chunk(
+            return self.text_chunker.chunk(file_content[node.start_byte : node.end_byte], file_metadata)
 
         chunks = []
         for child in node.children:
-            chunks.extend(self._chunk_node(child,
+            chunks.extend(self._chunk_node(child, file_content, file_metadata))
 
         for chunk in chunks:
             # This should always be true. Otherwise there must be a bug in the code.
-            assert chunk.
+            assert chunk.num_tokens <= self.max_tokens
 
         # Merge neighboring chunks if their combined size doesn't exceed max_tokens. The goal is to avoid pathologically
         # small chunks that end up being undeservedly preferred by the retriever.
@@ -129,16 +134,16 @@ class CodeChunker(Chunker):
         for chunk in chunks:
             if not merged_chunks:
                 merged_chunks.append(chunk)
-            elif merged_chunks[-1].num_tokens
+            elif merged_chunks[-1].num_tokens + chunk.num_tokens < self.max_tokens - 50:
                 # There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
                 # at this point, because tokenization is not necessarily additive.
-                merged =
-
+                merged = FileChunk(
+                    file_content,
+                    file_metadata,
                     merged_chunks[-1].start_byte,
                     chunk.end_byte,
                 )
-                merged.
-                if merged.num_tokens(self.tokenizer) <= self.max_tokens:
+                if merged.num_tokens <= self.max_tokens:
                     merged_chunks[-1] = merged
                 else:
                     merged_chunks.append(chunk)
@@ -148,20 +153,20 @@ class CodeChunker(Chunker):
 
         for chunk in merged_chunks:
             # This should always be true. Otherwise there's a bug worth investigating.
-            assert chunk.
+            assert chunk.num_tokens <= self.max_tokens
 
         return merged_chunks
 
     @staticmethod
     def is_code_file(filename: str) -> bool:
         """Checks whether pygment & tree_sitter can parse the file as code."""
-        language =
+        language = CodeFileChunker._get_language_from_filename(filename)
         return language and language not in ["text only", "None"]
 
     @staticmethod
     def parse_tree(filename: str, content: str) -> List[str]:
         """Parses the code in a file and returns the parse tree."""
-        language =
+        language = CodeFileChunker._get_language_from_filename(filename)
 
         if not language or language in ["text only", "None"]:
             logging.debug("%s doesn't seem to be a code file.", filename)
@@ -180,8 +185,12 @@ class CodeChunker(Chunker):
             return None
         return tree
 
-    def chunk(self,
+    def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
         """Chunks a code file into smaller pieces."""
+        file_content = content
+        file_metadata = metadata
+        file_path = metadata["file_path"]
+
         if not file_content.strip():
             return []
 
@@ -189,33 +198,33 @@ class CodeChunker(Chunker):
         if tree is None:
             return []
 
-
-        for chunk in
+        file_chunks = self._chunk_node(tree.root_node, file_content, file_metadata)
+        for chunk in file_chunks:
             # Make sure that the chunk has content and doesn't exceed the max_tokens limit. Otherwise there must be
             # a bug in the code.
-            assert chunk.
-            size = chunk.num_tokens(self.tokenizer)
-            assert size <= self.max_tokens, f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
+            assert chunk.num_tokens <= self.max_tokens, f"Chunk size {chunk.num_tokens} exceeds max_tokens {self.max_tokens}."
 
-        return
+        return file_chunks
 
 
-class
+class TextFileChunker(Chunker):
     """Wrapper around semchunk: https://github.com/umarbutler/semchunk."""
 
     def __init__(self, max_tokens: int):
         self.max_tokens = max_tokens
-
-        tokenizer = tiktoken.get_encoding("cl100k_base")
         self.count_tokens = lambda text: len(tokenizer.encode(text, disallowed_special=()))
 
-    def chunk(self,
+    def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
         """Chunks a text file into smaller pieces."""
+        file_content = content
+        file_metadata = metadata
+        file_path = file_metadata["file_path"]
+
         # We need to allocate some tokens for the filename, which is part of the chunk content.
         extra_tokens = self.count_tokens(file_path + "\n\n")
         text_chunks = chunk_via_semchunk(file_content, self.max_tokens - extra_tokens, self.count_tokens)
 
-
+        file_chunks = []
         start = 0
         for text_chunk in text_chunks:
             # This assertion should always be true. Otherwise there's a bug worth finding.
@@ -227,22 +236,25 @@ class TextChunker(Chunker):
                 logging.warning("Couldn't find semchunk in content: %s", text_chunk)
             else:
                 end = start + len(text_chunk)
-
+                file_chunks.append(FileChunk(file_content, file_metadata, start, end))
 
                 start = end
-
+
+        return file_chunks
 
 
-class
+class IpynbFileChunker(Chunker):
     """Extracts the python code from a Jupyter notebook, removing all the boilerplate.
 
     Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
     """
 
-    def __init__(self, code_chunker:
+    def __init__(self, code_chunker: CodeFileChunker):
         self.code_chunker = code_chunker
 
-    def chunk(self,
+    def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
+        filename = metadata["file_path"]
+
         if not filename.lower().endswith(".ipynb"):
             logging.warn("IPYNBChunker is only for .ipynb files.")
             return []
@@ -256,16 +268,25 @@ class IPYNBChunker(Chunker):
         return chunks
 
 
-class
+class UniversalFileChunker(Chunker):
     """Chunks a file into smaller pieces, regardless of whether it's code or text."""
 
     def __init__(self, max_tokens: int):
-        self.code_chunker =
-        self.
+        self.code_chunker = CodeFileChunker(max_tokens)
+        self.ipynb_chunker = IpynbFileChunker(self.code_chunker)
+        self.text_chunker = TextFileChunker(max_tokens)
 
-    def chunk(self,
+    def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
+        if not "file_path" in metadata:
+            raise ValueError("metadata must contain a 'file_path' key.")
+        file_path = metadata["file_path"]
+
+        # Figure out the appropriate chunker to use.
         if file_path.lower().endswith(".ipynb"):
-
-        if
-
-
+            chunker = self.ipynb_chunker
+        if CodeFileChunker.is_code_file(file_path):
+            chunker = self.code_chunker
+        else:
+            chunker = self.text_chunker
+
+        return chunker.chunk(content, metadata)
```
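The subtle part of `CodeFileChunker._chunk_node` above is the merging pass: neighboring chunks are greedily coalesced while they fit under `max_tokens`, and the merged span is re-counted because token counts are not additive. This is a self-contained sketch of that loop, with hypothetical names (`Span`, `merge_neighbors`) and a whitespace word count standing in for the real tiktoken `cl100k_base` tokenizer:

```python
from dataclasses import dataclass

def num_tokens(text: str) -> int:
    # Stand-in tokenizer: the real chunker counts BPE tokens via tiktoken.
    return len(text.split())

@dataclass
class Span:
    start: int  # Byte/char offset where the chunk begins.
    end: int    # Offset one past the chunk's last character.

def merge_neighbors(content: str, spans, max_tokens: int):
    """Greedily merge adjacent spans while the merged span stays under max_tokens."""
    merged = []
    for span in spans:
        if not merged:
            merged.append(span)
            continue
        candidate = Span(merged[-1].start, span.end)
        # Re-count on the merged text: tokenization is not necessarily additive.
        if num_tokens(content[candidate.start:candidate.end]) <= max_tokens:
            merged[-1] = candidate
        else:
            merged.append(span)
    return merged

content = "one two three four five six"
spans = [Span(0, 7), Span(8, 13), Span(14, 27)]  # "one two", "three", "four five six"
print([content[s.start:s.end] for s in merge_neighbors(content, spans, 4)])
# ['one two three', 'four five six']
```

The goal, as the diff's comment says, is to avoid pathologically small chunks that the retriever would otherwise undeservedly prefer.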
src/{repo_manager.py → data_manager.py}
RENAMED

```diff
@@ -2,13 +2,28 @@
 
 import logging
 import os
+from abc import abstractmethod
 from functools import cached_property
+from typing import Any, Dict, Generator, Tuple
 
 import requests
 from git import GitCommandError, Repo
 
 
-class
+class DataManager:
+    def __init__(self, dataset_id: str):
+        self.dataset_id = dataset_id
+
+    @abstractmethod
+    def download(self) -> bool:
+        """Downloads the data from a remote location."""
+
+    @abstractmethod
+    def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
+        """Yields a tuple of (data, metadata) for each data item in the dataset."""
+
+
+class GitHubRepoManager(DataManager):
     """Class to manage a local clone of a GitHub repository."""
 
     def __init__(
@@ -23,11 +38,18 @@ class RepoManager:
             repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/repo2vec".
             local_dir: The local directory where the repository will be cloned.
         """
+        super().__init__(dataset_id=repo_id)
         self.repo_id = repo_id
+
         self.local_dir = local_dir or "/tmp/"
         if not os.path.exists(self.local_dir):
             os.makedirs(self.local_dir)
         self.local_path = os.path.join(self.local_dir, repo_id)
+
+        self.log_dir = os.path.join(self.local_dir, "logs", repo_id)
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+
         self.access_token = os.getenv("GITHUB_TOKEN")
         self.included_extensions = included_extensions
         self.excluded_extensions = excluded_extensions
@@ -58,7 +80,7 @@ class RepoManager:
             branch = "main"
         return branch
 
-    def
+    def download(self) -> bool:
         """Clones the repository to the local directory, if it's not already cloned."""
         if os.path.exists(self.local_path):
             # The repository is already cloned.
@@ -94,38 +116,35 @@ class RepoManager:
             return False
         return True
 
-    def walk(self
-        """Walks the local repository path and yields a tuple of (
+    def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
+        """Walks the local repository path and yields a tuple of (content, metadata) for each file.
         The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
 
         Args:
             included_extensions: Optional set of extensions to include.
             excluded_extensions: Optional set of extensions to exclude.
-            log_dir: Optional directory where to log the included and excluded files.
         """
         # We will keep appending to these files during the iteration, so we need to clear them first.
+        repo_name = self.repo_id.replace("/", "_")
+        included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
+        excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
-
-
-
-
-
-        os.remove(excluded_log_file)
 
         for root, _, files in os.walk(self.local_path):
             file_paths = [os.path.join(root, file) for file in files]
             included_file_paths = [f for f in file_paths if self._should_include(f)]
 
-
-
-
-                f.write(path + "\n")
 
-
-
-
 
         for file_path in included_file_paths:
             with open(file_path, "r") as f:
@@ -134,9 +153,14 @@ class RepoManager:
             except UnicodeDecodeError:
                 logging.warning("Unable to decode file %s. Skipping.", file_path)
                 continue
-
-
-
         """Converts a repository file path to a GitHub link."""
-        file_path = file_path[len(self.repo_id) :]
         return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
```
|
| 131 |
+
if os.path.exists(included_log_file):
|
| 132 |
+
os.remove(included_log_file)
|
| 133 |
+
if os.path.exists(excluded_log_file):
|
| 134 |
+
os.remove(excluded_log_file)
|
|
|
|
| 135 |
|
| 136 |
for root, _, files in os.walk(self.local_path):
|
| 137 |
file_paths = [os.path.join(root, file) for file in files]
|
| 138 |
included_file_paths = [f for f in file_paths if self._should_include(f)]
|
| 139 |
|
| 140 |
+
with open(included_log_file, "a") as f:
|
| 141 |
+
for path in included_file_paths:
|
| 142 |
+
f.write(path + "\n")
|
|
|
|
| 143 |
|
| 144 |
+
excluded_file_paths = set(file_paths).difference(set(included_file_paths))
|
| 145 |
+
with open(excluded_log_file, "a") as f:
|
| 146 |
+
for path in excluded_file_paths:
|
| 147 |
+
f.write(path + "\n")
|
| 148 |
|
| 149 |
for file_path in included_file_paths:
|
| 150 |
with open(file_path, "r") as f:
|
|
|
|
| 153 |
except UnicodeDecodeError:
|
| 154 |
logging.warning("Unable to decode file %s. Skipping.", file_path)
|
| 155 |
continue
|
| 156 |
+
relative_file_path = file_path[len(self.local_dir) + 1 :]
|
| 157 |
+
metadata = {
|
| 158 |
+
"file_path": relative_file_path,
|
| 159 |
+
"url": self.url_for_file(relative_file_path),
|
| 160 |
+
}
|
| 161 |
+
yield contents, metadata
|
| 162 |
+
|
| 163 |
+
def url_for_file(self, file_path: str) -> str:
|
| 164 |
"""Converts a repository file path to a GitHub link."""
|
| 165 |
+
file_path = file_path[len(self.repo_id) + 1 :]
|
| 166 |
return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
|
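As a quick check of the `DataManager` contract above (`download()` fetches data, `walk()` yields `(content, metadata)` pairs), here is a minimal sketch with a toy in-memory subclass. `InMemoryManager` and its hard-coded documents are illustrative only, not part of the commit:

```python
from abc import abstractmethod
from typing import Any, Dict, Generator, Tuple


class DataManager:
    """Mirrors the interface from src/data_manager.py."""

    def __init__(self, dataset_id: str):
        self.dataset_id = dataset_id

    @abstractmethod
    def download(self) -> bool:
        """Downloads the data from a remote location."""

    @abstractmethod
    def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
        """Yields a tuple of (data, metadata) for each data item in the dataset."""


class InMemoryManager(DataManager):
    """Toy manager that serves a hard-coded list of documents."""

    def __init__(self):
        super().__init__(dataset_id="toy/dataset")
        self._docs = [("hello world", {"id": 1}), ("goodbye", {"id": 2})]

    def download(self) -> bool:
        return True  # nothing to fetch

    def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
        yield from self._docs


manager = InMemoryManager()
manager.download()
contents = [content for content, _ in manager.walk()]  # ["hello world", "goodbye"]
```

Any such subclass (repository files, GitHub issues, or anything else) can then be fed to the embedders below, which only rely on `walk()`.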
src/embedder.py
CHANGED
The diff for `src/embedder.py` threads the new `DataManager` through both embedders and makes the OpenAI `dimensions` parameter optional. The new version, with regions the diff view elides marked `# ...`:

```python
import logging
import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple

import marqo
from openai import OpenAI

from chunker import Chunk, Chunker
from data_manager import DataManager

Vector = Tuple[Dict, List[float]]  # (metadata, embedding)


class BatchEmbedder(ABC):
    """Abstract class for batch embedding of a dataset."""

    @abstractmethod
    def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
        """Issues batch embedding jobs for the entire dataset."""

    @abstractmethod
    def embeddings_are_ready(self) -> bool:
        # ...

    @abstractmethod
    def download_embeddings(self) -> Generator[Vector, None, None]:
        """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""


class OpenAIBatchEmbedder(BatchEmbedder):
    """Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""

    def __init__(
        self, data_manager: DataManager, chunker: Chunker, local_dir: str, embedding_model: str, embedding_size: int
    ):
        self.data_manager = data_manager
        self.chunker = chunker
        self.local_dir = local_dir
        self.embedding_model = embedding_model
        # ...
        self.openai_batch_ids = {}
        self.client = OpenAI()

    def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
        """Issues batch embedding jobs for the entire dataset."""
        if self.openai_batch_ids:
            raise ValueError("Embeddings are in progress.")

        batch = []
        chunk_count = 0
        dataset_name = self.data_manager.dataset_id.split("/")[-1]

        for content, metadata in self.data_manager.walk():
            chunks = self.chunker.chunk(content, metadata)
            chunk_count += len(chunks)
            batch.extend(chunks)

            # ...
            for i in range(0, len(batch), chunks_per_batch):
                sub_batch = batch[i : i + chunks_per_batch]
                openai_batch_id = self._issue_job_for_chunks(
                    sub_batch, batch_id=f"{dataset_name}/{len(self.openai_batch_ids)}"
                )
                self.openai_batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
                if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
                    logging.info("Reached the maximum number of embedding jobs. Stopping.")
                    return
        # ...

        # Finally, commit the last batch.
        if batch:
            openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{dataset_name}/{len(self.openai_batch_ids)}")
            self.openai_batch_ids[openai_batch_id] = [chunk.metadata for chunk in batch]
        logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)

        # Save the job IDs to a file, just in case this script is terminated by mistake.
        # ...
        return are_ready

    def download_embeddings(self) -> Generator[Vector, None, None]:
        """Yield a (chunk_metadata, embedding) pair for each chunk in the dataset."""
        job_ids = self.openai_batch_ids.keys()
        statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in job_ids]
        # ...
            f.write("\n")

    @staticmethod
    def _chunks_to_request(chunks: List[Chunk], batch_id: str, model: str, dimensions: Optional[int] = None) -> Dict:
        """Convert a list of chunks to a batch request."""
        body = {
            "model": model,
            "input": [chunk.content for chunk in chunks],
        }

        # These are the only two models that support a dynamic embedding size.
        if model in ["text-embedding-3-small", "text-embedding-3-large"] and dimensions is not None:
            body["dimensions"] = dimensions

        return {
            "custom_id": batch_id,
            "method": "POST",
            "url": "/v1/embeddings",
            "body": body,
        }


class MarqoEmbedder(BatchEmbedder):
    # ...
    Embeddings can be stored locally (in which case `url` in the constructor should point to localhost) or in the cloud.
    """

    def __init__(self, data_manager: DataManager, chunker: Chunker, index_name: str, url: str, model="hf/e5-base-v2"):
        self.data_manager = data_manager
        self.chunker = chunker
        self.client = marqo.Client(url=url)
        self.index = self.client.index(index_name)
        # ...
        if not index_name in all_index_names:
            self.client.create_index(index_name, model=model)

    def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
        """Issues batch embedding jobs for the entire dataset."""
        if chunks_per_batch > 64:
            raise ValueError("Marqo enforces a limit of 64 chunks per batch.")

        chunk_count = 0
        batch = []

        for content, metadata in self.data_manager.walk():
            chunks = self.chunker.chunk(content, metadata)
            chunk_count += len(chunks)
            batch.extend(chunks)

            # ...
            sub_batch = batch[i : i + chunks_per_batch]
            logging.info("Indexing %d chunks...", len(sub_batch))
            self.index.add_documents(
                documents=[chunk.metadata for chunk in sub_batch],
                tensor_fields=["text"],
            )
        # ...

        # Finally, commit the last batch.
        if batch:
            self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=["text"])
        logging.info(f"Successfully embedded {chunk_count} chunks.")

    def embeddings_are_ready(self) -> bool:
        """Checks whether the batch embedding jobs are done."""
        # Marqo indexes documents synchronously, so once embed_dataset() returns, the embeddings are ready.
        return True

    def download_embeddings(self) -> Generator[Vector, None, None]:
        """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
        # Marqo stores embeddings as they are created, so they're already in the vector store. No need to download them
        # as we would with e.g. OpenAI, Cohere, or some other cloud-based embedding service.
        return []


def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
    if args.embedder_type == "openai":
        return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
    elif args.embedder_type == "marqo":
        return MarqoEmbedder(
            data_manager, chunker, index_name=args.index_name, url=args.marqo_url, model=args.embedding_model
        )
    else:
        raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
```
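Both embedders build one flat `batch` of chunks and then slice it with `range(0, len(batch), chunks_per_batch)`. That slicing can be checked in isolation; in this sketch, plain strings stand in for `Chunk` objects:

```python
def split_into_sub_batches(batch, chunks_per_batch):
    """Mirrors the sub-batch slicing used by both embedders."""
    return [batch[i : i + chunks_per_batch] for i in range(0, len(batch), chunks_per_batch)]


batch = [f"chunk-{i}" for i in range(10)]
sub_batches = split_into_sub_batches(batch, chunks_per_batch=4)
# Sizes are 4, 4, 2: every chunk lands in exactly one sub-batch, in order.
```

The final partial slice is why both embedders need the trailing "commit the last batch" step when they issue jobs incrementally.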
src/github.py
ADDED
`src/github.py` is new. It defines dataclasses for issues and comments, a `GitHubIssuesManager` that pages through the GitHub REST API, and a `GitHubIssuesChunker` that packs contiguous comments into token-bounded chunks:

```python
"""GitHub-specific implementations for DataManager and Chunker."""

import os
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Tuple

import logging
import requests
import tiktoken

from chunker import Chunk, Chunker
from data_manager import DataManager

tokenizer = tiktoken.get_encoding("cl100k_base")


@dataclass
class GitHubIssueComment:
    """A comment on a GitHub issue."""

    url: str
    html_url: str
    body: str

    @property
    def pretty(self):
        return f"""## Comment: {self.body}"""


@dataclass
class GitHubIssue:
    """A GitHub issue."""

    url: str
    html_url: str
    title: str
    body: str
    comments: List[GitHubIssueComment]

    @property
    def pretty(self):
        # Do not include the comments.
        return f"# Issue: {self.title}\n{self.body}"


class GitHubIssuesManager(DataManager):
    """Class to manage the GitHub issues of a particular repository."""

    def __init__(self, repo_id: str, max_issues: int = None):
        super().__init__(dataset_id=repo_id + "/issues")
        self.repo_id = repo_id
        self.max_issues = max_issues
        self.access_token = os.getenv("GITHUB_TOKEN")
        if not self.access_token:
            raise ValueError("Please set the GITHUB_TOKEN environment variable when indexing GitHub issues.")
        self.issues = []

    def download(self) -> bool:
        """Downloads all open issues from a GitHub repository (including the comments)."""
        per_page = min(self.max_issues or 100, 100)  # 100 is the maximum per page
        url = f"https://api.github.com/repos/{self.repo_id}/issues?per_page={per_page}"
        while url:
            print(f"Fetching issues from {url}")
            response = self._get_page_of_issues(url)
            response.raise_for_status()
            for issue in response.json():
                if not "pull_request" in issue:
                    self.issues.append(
                        GitHubIssue(
                            url=issue["url"],
                            html_url=issue["html_url"],
                            title=issue["title"],
                            # When there's no body, issue["body"] is None.
                            body=issue["body"] or "",
                            comments=self._get_comments(issue["comments_url"]),
                        )
                    )
            if self.max_issues and len(self.issues) >= self.max_issues:
                break
            url = GitHubIssuesManager._get_next_link_from_header(response)
        return True

    def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
        """Yields a tuple of (issue_content, issue_metadata) for each GitHub issue in the repository."""
        for issue in self.issues:
            yield issue, {}  # empty metadata

    @staticmethod
    def _get_next_link_from_header(response):
        """
        Given a response from a paginated request, extracts the URL of the next page.

        Example:
            response.headers.get("link") = '<https://api.github.com/repositories/2503910/issues?per_page=10&page=2>; rel="next", <https://api.github.com/repositories/2503910/issues?per_page=10&page=2>; rel="last"'
            get_next_link_from_header(response) = 'https://api.github.com/repositories/2503910/issues?per_page=10&page=2'
        """
        link_header = response.headers.get("link")
        if link_header:
            links = link_header.split(", ")
            for link in links:
                url, rel = link.split("; ")
                url = url[1:-1]  # The URL is enclosed in angle brackets
                rel = rel[5:-1]  # e.g. rel="next" -> next
                if rel == "next":
                    return url
        return None

    def _get_page_of_issues(self, url):
        """Downloads a single page of issues. Note that GitHub uses pagination for long lists of objects."""
        return requests.get(
            url,
            headers={
                "Authorization": f"Bearer {self.access_token}",
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )

    def _get_comments(self, comments_url) -> List[GitHubIssueComment]:
        """Downloads all the comments associated with an issue; returns an empty list if the request times out."""
        try:
            response = requests.get(
                comments_url,
                headers={
                    "Authorization": f"Bearer {self.access_token}",
                    "X-GitHub-Api-Version": "2022-11-28",
                },
            )
        except requests.exceptions.ConnectionTimeout:
            logging.warn(f"Timeout fetching comments from {comments_url}")
            return []
        comments = []
        for comment in response.json():
            comments.append(
                GitHubIssueComment(
                    url=comment["url"],
                    html_url=comment["html_url"],
                    body=comment["body"],
                )
            )
        return comments


@dataclass
class IssueChunk(Chunk):
    """A chunk from a GitHub issue with a contiguous (sub)set of comments.

    Note that, in comparison to FileChunk, its properties are not cached. We want to allow fields to be changed in place
    and have e.g. the token count be recomputed. Compared to files, GitHub issues are typically smaller, so the overhead
    is less problematic.
    """

    issue: GitHubIssue
    start_comment: int
    end_comment: int  # exclusive

    @property
    def content(self) -> str:
        """The title of the issue, followed by the comments in the chunk."""
        if self.start_comment == 0:
            # This is the first subsequence of comments. We'll include the entire body of the issue.
            issue_str = self.issue.pretty
        else:
            # This is a middle subsequence of comments. We'll only include the title of the issue.
            issue_str = f"# Issue: {self.issue.title}"
        # Now add the comments themselves.
        comments = self.issue.comments[self.start_comment : self.end_comment]
        comments_str = "\n\n".join([comment.pretty for comment in comments])
        return issue_str + "\n\n" + comments_str

    @property
    def metadata(self):
        """Converts the chunk to a dictionary that can be passed to a vector store."""
        return {
            "id": f"{self.issue.html_url}_{self.start_comment}_{self.end_comment}",
            "url": self.issue.html_url,
            "start_comment": self.start_comment,
            "end_comment": self.end_comment,
            # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
            # size limit. In that case, you can simply store the start/end comment indices above, and fetch the
            # content of the issue on demand from the URL.
            "text": self.content,
        }

    @property
    def num_tokens(self):
        """Number of tokens in this chunk."""
        return len(tokenizer.encode(self.content, disallowed_special=()))


class GitHubIssuesChunker(Chunker):
    """Chunks a GitHub issue into smaller pieces of contiguous (sub)sets of comments."""

    def __init__(self, max_tokens: int):
        self.max_tokens = max_tokens

    def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
        """Chunks a GitHub issue into subsequences of comments."""
        del metadata  # The metadata of the input issue is unused.

        issue = content  # Rename for clarity.
        if not isinstance(issue, GitHubIssue):
            raise ValueError(f"Expected a GitHubIssue, got {type(issue)}.")

        chunks = []

        # First, create a chunk for the issue body.
        issue_body_chunk = IssueChunk(issue, 0, 0)
        chunks.append(issue_body_chunk)

        for comment_idx, comment in enumerate(issue.comments):
            # This is just approximate, because when we actually add a comment to the chunk there might be some extra
            # tokens, like a "Comment:" prefix.
            approx_comment_size = len(tokenizer.encode(comment.body, disallowed_special=())) + 20  # 20 for buffer

            if chunks[-1].num_tokens + approx_comment_size > self.max_tokens:
                # Create a new chunk starting from this comment.
                chunks.append(
                    IssueChunk(
                        issue=issue,
                        start_comment=comment_idx,
                        end_comment=comment_idx + 1,
                    )
                )
            else:
                # Add the comment to the existing chunk.
                chunks[-1].end_comment = comment_idx + 1
        return chunks
```
src/index.py
CHANGED
@@ -4,9 +4,10 @@ import argparse
|
|
| 4 |
import logging
|
| 5 |
import time
|
| 6 |
|
| 7 |
-
from chunker import
|
| 8 |
-
from
|
| 9 |
-
from
|
|
|
|
| 10 |
from vector_store import build_from_args
|
| 11 |
|
| 12 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -31,43 +32,42 @@ def _read_extensions(path):
|
|
| 31 |
|
| 32 |
|
| 33 |
def main():
|
| 34 |
-
parser = argparse.ArgumentParser(description="Batch-embeds a repository")
|
| 35 |
parser.add_argument("repo_id", help="The ID of the repository to index")
|
| 36 |
-
parser.add_argument("--
|
| 37 |
parser.add_argument(
|
| 38 |
-
"--
|
| 39 |
type=str,
|
| 40 |
default=None,
|
| 41 |
help="The embedding model. Defaults to `text-embedding-ada-002` for OpenAI and `hf/e5-base-v2` for Marqo.",
|
| 42 |
)
|
| 43 |
parser.add_argument(
|
| 44 |
-
"--
|
| 45 |
type=int,
|
| 46 |
default=None,
|
| 47 |
-
help="The embedding size to use for OpenAI
|
| 48 |
-
"
|
| 49 |
-
"No need to specify an embedding size for Marqo, since the embedding model determines it.",
|
| 50 |
)
|
| 51 |
-
parser.add_argument("--
|
| 52 |
parser.add_argument(
|
| 53 |
-
"--
|
| 54 |
default="repos",
|
| 55 |
help="The local directory to store the repository",
|
| 56 |
)
|
| 57 |
parser.add_argument(
|
| 58 |
-
"--
|
| 59 |
type=int,
|
| 60 |
default=800,
|
| 61 |
help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
|
| 62 |
)
|
| 63 |
parser.add_argument(
|
| 64 |
-
"--
|
| 65 |
type=int,
|
| 66 |
default=2000,
|
| 67 |
help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
|
| 68 |
)
|
| 69 |
parser.add_argument(
|
| 70 |
-
"--
|
| 71 |
required=True,
|
| 72 |
help="Vector store index name. For Pinecone, make sure to create it with the right embedding size.",
|
| 73 |
)
|
|
@@ -81,16 +81,30 @@ def main():
|
|
| 81 |
help="Path to a file containing a list of extensions to exclude. One extension per line.",
|
| 82 |
)
|
| 83 |
parser.add_argument(
|
| 84 |
-
"--
|
| 85 |
type=int,
|
| 86 |
help="Maximum number of embedding jobs to run. Specifying this might result in "
|
| 87 |
"indexing only part of the repository, but prevents you from burning through OpenAI credits.",
|
| 88 |
)
|
| 89 |
parser.add_argument(
|
| 90 |
-
"--
|
| 91 |
default="http://localhost:8882",
|
| 92 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 93 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
args = parser.parse_args()
|
| 95 |
|
| 96 |
# Validate embedder and vector store compatibility.
|
|
@@ -111,56 +125,81 @@ def main():
|
|
| 111 |
parser.error(f"The maximum number of chunks per job is {MAX_TOKENS_PER_JOB}.")
|
| 112 |
if args.include and args.exclude:
|
| 113 |
parser.error("At most one of --include and --exclude can be specified.")
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# Set default values based on other arguments
|
| 116 |
if args.embedding_model is None:
|
| 117 |
args.embedding_model = "text-embedding-ada-002" if args.embedder_type == "openai" else "hf/e5-base-v2"
|
| 118 |
if args.embedding_size is None and args.embedder_type == "openai":
|
| 119 |
args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
|
| 120 |
-
# No need to set embedding_size for Marqo, since the embedding model determines the embedding size.
|
| 121 |
-
logging.warn("--embedding_size is ignored for Marqo embedder.")
|
| 122 |
-
|
| 123 |
-
included_extensions = _read_extensions(args.include) if args.include else None
|
| 124 |
-
excluded_extensions = _read_extensions(args.exclude) if args.exclude else None
|
| 125 |
-
|
| 126 |
-
logging.info("Cloning the repository...")
|
| 127 |
-
repo_manager = RepoManager(
|
| 128 |
-
args.repo_id,
|
| 129 |
-
local_dir=args.local_dir,
|
| 130 |
-
included_extensions=included_extensions,
|
| 131 |
-
excluded_extensions=excluded_extensions,
|
| 132 |
-
)
|
| 133 |
-
repo_manager.clone()
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
|
| 149 |
if args.vector_store_type == "marqo":
|
| 150 |
# Marqo computes embeddings and stores them in the vector store at once, so we're done.
|
| 151 |
logging.info("Done!")
|
| 152 |
return
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
|
| 159 |
-
logging.info("Moving embeddings to the vector store...")
|
| 160 |
-
# Note to developer: Replace this with your preferred vector store.
|
| 161 |
-
vector_store = build_from_args(args)
|
| 162 |
-
vector_store.ensure_exists()
|
| 163 |
-
vector_store.upsert(embedder.download_embeddings())
|
| 164 |
logging.info("Done!")
|
| 165 |
|
| 166 |
|
| 4 |
import logging
|
| 5 |
import time
|
| 6 |
|
| 7 |
+
from chunker import UniversalFileChunker
|
| 8 |
+
from data_manager import GitHubRepoManager
|
| 9 |
+
from embedder import build_batch_embedder_from_flags
|
| 10 |
+
from github import GitHubIssuesChunker, GitHubIssuesManager
|
| 11 |
from vector_store import build_from_args
|
| 12 |
|
| 13 |
logging.basicConfig(level=logging.INFO)
|
| 32 |
|
| 33 |
|
| 34 |
def main():
|
| 35 |
+
parser = argparse.ArgumentParser(description="Batch-embeds a GitHub repository and its issues.")
|
| 36 |
parser.add_argument("repo_id", help="The ID of the repository to index")
|
| 37 |
+
parser.add_argument("--embedder-type", default="openai", choices=["openai", "marqo"])
|
| 38 |
parser.add_argument(
|
| 39 |
+
"--embedding-model",
|
| 40 |
type=str,
|
| 41 |
default=None,
|
| 42 |
help="The embedding model. Defaults to `text-embedding-ada-002` for OpenAI and `hf/e5-base-v2` for Marqo.",
|
| 43 |
)
|
| 44 |
parser.add_argument(
|
| 45 |
+
"--embedding-size",
|
| 46 |
type=int,
|
| 47 |
default=None,
|
| 48 |
+
help="The embedding size to use for OpenAI text-embedding-3* models. Defaults to 1536 for small and 3072 for "
|
| 49 |
+
"large. Note that no other OpenAI models support a dynamic embedding size, nor do models used with Marqo.",
|
| 50 |
)
|
| 51 |
+
parser.add_argument("--vector-store-type", default="pinecone", choices=["pinecone", "marqo"])
|
| 52 |
parser.add_argument(
|
| 53 |
+
"--local-dir",
|
| 54 |
default="repos",
|
| 55 |
help="The local directory to store the repository",
|
| 56 |
)
|
| 57 |
parser.add_argument(
|
| 58 |
+
"--tokens-per-chunk",
|
| 59 |
type=int,
|
| 60 |
default=800,
|
| 61 |
help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
|
| 62 |
)
|
| 63 |
parser.add_argument(
|
| 64 |
+
"--chunks-per-batch",
|
| 65 |
type=int,
|
| 66 |
default=2000,
|
| 67 |
help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
|
| 68 |
)
|
| 69 |
parser.add_argument(
|
| 70 |
+
"--index-name",
|
| 71 |
required=True,
|
| 72 |
help="Vector store index name. For Pinecone, make sure to create it with the right embedding size.",
|
| 73 |
)
|
| 81 |
help="Path to a file containing a list of extensions to exclude. One extension per line.",
|
| 82 |
)
|
| 83 |
parser.add_argument(
|
| 84 |
+
"--max-embedding-jobs",
|
| 85 |
type=int,
|
| 86 |
help="Maximum number of embedding jobs to run. Specifying this might result in "
|
| 87 |
"indexing only part of the repository, but prevents you from burning through OpenAI credits.",
|
| 88 |
)
|
| 89 |
parser.add_argument(
|
| 90 |
+
"--marqo-url",
|
| 91 |
default="http://localhost:8882",
|
| 92 |
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 93 |
)
|
| 94 |
+
# Pass --no-index-repo in order to not index the repository.
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--index-repo",
|
| 97 |
+
action=argparse.BooleanOptionalAction,
|
| 98 |
+
default=True,
|
| 99 |
+
help="Whether to index the repository. At least one of --index-repo and --index-issues must be True.",
|
| 100 |
+
)
|
| 101 |
+
# Pass --no-index-issues in order to not index the issues.
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--index-issues",
|
| 104 |
+
action=argparse.BooleanOptionalAction,
|
| 105 |
+
default=True,
|
| 106 |
+
help="Whether to index GitHub issues. At least one of --index-repo and --index-issues must be True.",
|
| 107 |
+
)
|
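The new `--index-repo` / `--index-issues` flags use `argparse.BooleanOptionalAction` (Python 3.9+), which auto-generates the `--no-*` negations mentioned in the comments above. A minimal standalone sketch of that behavior:

```python
import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction also registers a --no-index-issues negation flag.
parser.add_argument("--index-issues", action=argparse.BooleanOptionalAction, default=True)

assert parser.parse_args([]).index_issues is True            # default
assert parser.parse_args(["--no-index-issues"]).index_issues is False
assert parser.parse_args(["--index-issues"]).index_issues is True
```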
| 108 |
args = parser.parse_args()
|
| 109 |
|
| 110 |
# Validate embedder and vector store compatibility.
|
| 125 |
parser.error(f"The maximum number of tokens per job is {MAX_TOKENS_PER_JOB}.")
|
| 126 |
if args.include and args.exclude:
|
| 127 |
parser.error("At most one of --include and --exclude can be specified.")
|
| 128 |
+
if not args.index_repo and not args.index_issues:
|
| 129 |
+
parser.error("At least one of --index-repo and --index-issues must be true.")
|
| 130 |
|
| 131 |
# Set default values based on other arguments
|
| 132 |
if args.embedding_model is None:
|
| 133 |
args.embedding_model = "text-embedding-ada-002" if args.embedder_type == "openai" else "hf/e5-base-v2"
|
| 134 |
if args.embedding_size is None and args.embedder_type == "openai":
|
| 135 |
args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
|
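The flag-defaulting logic above can be isolated as a pure function. Note that the `OPENAI_DEFAULT_EMBEDDING_SIZE` values below are assumptions inferred from the `--embedding-size` help text (1536 for `text-embedding-3-small`, 3072 for `text-embedding-3-large`), not the actual constant from this repository:

```python
# Assumed mapping, inferred from the --embedding-size help text.
OPENAI_DEFAULT_EMBEDDING_SIZE = {
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
}

def resolve_embedding_defaults(embedder_type, embedding_model=None, embedding_size=None):
    """Mirror the defaulting performed in main() after parse_args()."""
    if embedding_model is None:
        embedding_model = "text-embedding-ada-002" if embedder_type == "openai" else "hf/e5-base-v2"
    if embedding_size is None and embedder_type == "openai":
        # Models without a dynamic embedding size resolve to None here.
        embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(embedding_model)
    return embedding_model, embedding_size
```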
| 136 |
|
| 137 |
+
######################
|
| 138 |
+
# Step 1: Embeddings #
|
| 139 |
+
######################
|
| 140 |
+
|
| 141 |
+
# Index the repository.
|
| 142 |
+
repo_embedder = None
|
| 143 |
+
if args.index_repo:
|
| 144 |
+
included_extensions = _read_extensions(args.include) if args.include else None
|
| 145 |
+
excluded_extensions = _read_extensions(args.exclude) if args.exclude else None
|
| 146 |
+
|
| 147 |
+
logging.info("Cloning the repository...")
|
| 148 |
+
repo_manager = GitHubRepoManager(
|
| 149 |
+
args.repo_id,
|
| 150 |
+
local_dir=args.local_dir,
|
| 151 |
+
included_extensions=included_extensions,
|
| 152 |
+
excluded_extensions=excluded_extensions,
|
| 153 |
)
|
| 154 |
+
repo_manager.download()
|
| 155 |
+
logging.info("Embedding the repo...")
|
| 156 |
+
chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk)
|
| 157 |
+
repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args)
|
| 158 |
+
repo_embedder.embed_dataset(args.chunks_per_batch, args.max_embedding_jobs)
|
| 159 |
+
|
| 160 |
+
# Index the GitHub issues.
|
| 161 |
+
issues_embedder = None
|
| 162 |
+
|
| 163 |
+
if args.index_issues:
|
| 164 |
+
logging.info("Issuing embedding jobs for GitHub issues...")
|
| 165 |
+
issues_manager = GitHubIssuesManager(args.repo_id)
|
| 166 |
+
issues_manager.download()
|
| 167 |
+
logging.info("Embedding GitHub issues...")
|
| 168 |
+
chunker = GitHubIssuesChunker(max_tokens=args.tokens_per_chunk)
|
| 169 |
+
issues_embedder = build_batch_embedder_from_flags(issues_manager, chunker, args)
|
| 170 |
+
issues_embedder.embed_dataset(args.chunks_per_batch, args.max_embedding_jobs)
|
| 171 |
+
|
| 172 |
+
########################
|
| 173 |
+
# Step 2: Vector Store #
|
| 174 |
+
########################
|
| 175 |
|
| 176 |
if args.vector_store_type == "marqo":
|
| 177 |
# Marqo computes embeddings and stores them in the vector store at once, so we're done.
|
| 178 |
logging.info("Done!")
|
| 179 |
return
|
| 180 |
|
| 181 |
+
if repo_embedder is not None:
|
| 182 |
+
logging.info("Waiting for repo embeddings to be ready...")
|
| 183 |
+
while not repo_embedder.embeddings_are_ready():
|
| 184 |
+
logging.info("Sleeping for 30 seconds...")
|
| 185 |
+
time.sleep(30)
|
| 186 |
+
|
| 187 |
+
logging.info("Moving embeddings to the repo vector store...")
|
| 188 |
+
repo_vector_store = build_from_args(args)
|
| 189 |
+
repo_vector_store.ensure_exists()
|
| 190 |
+
repo_vector_store.upsert(repo_embedder.download_embeddings())
|
| 191 |
+
|
| 192 |
+
if issues_embedder is not None:
|
| 193 |
+
logging.info("Waiting for issue embeddings to be ready...")
|
| 194 |
+
while not issues_embedder.embeddings_are_ready():
|
| 195 |
+
logging.info("Sleeping for 30 seconds...")
|
| 196 |
+
time.sleep(30)
|
| 197 |
+
|
| 198 |
+
logging.info("Moving embeddings to the issues vector store...")
|
| 199 |
+
issues_vector_store = build_from_args(args)
|
| 200 |
+
issues_vector_store.ensure_exists()
|
| 201 |
+
issues_vector_store.upsert(issues_embedder.download_embeddings())
|
| 202 |
|
|
| 203 |
logging.info("Done!")
|
| 204 |
|
| 205 |
|
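Both wait loops above poll `embeddings_are_ready()` every 30 seconds with no upper bound, so a stalled batch job would keep the script alive forever. The same pattern with a timeout guard can be sketched as follows (the helper name is hypothetical, not part of this codebase):

```python
import time

def wait_until_ready(is_ready, poll_seconds=30, timeout_seconds=4 * 3600):
    """Poll is_ready() until it returns True, or raise TimeoutError at the deadline."""
    deadline = time.monotonic() + timeout_seconds
    while not is_ready():
        if time.monotonic() >= deadline:
            raise TimeoutError("Embeddings were not ready before the deadline.")
        time.sleep(poll_seconds)

# Example: a callable that becomes ready on the third poll.
calls = {"count": 0}
def fake_ready():
    calls["count"] += 1
    return calls["count"] >= 3

wait_until_ready(fake_ready, poll_seconds=0, timeout_seconds=5)
assert calls["count"] == 3
```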