Add official documentation instead of README (#93)
* Add Anthropic reranker. Switch to "remote" config by default.
* Update README to point to docs.
* Format tests.
- README.md +13 -242
- sage/config.py +2 -2
- sage/configs/remote.yaml +6 -8
- sage/reranker.py +64 -2
- tests/test_data_manager.py +47 -56
- tests/test_github.py +83 -50
README.md
CHANGED

@@ -8,6 +8,11 @@
Before:
 <a href="https://github.com/Storia-AI/sage/stargazers" target="_blank"><img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/Storia-AI/sage?logo=github&link=https%3A%2F%2Fgithub.com%2FStoria-AI%2Fsage%2Fstargazers"></a>
 <a href="https://github.com/Storia-AI/sage/blob/main/LICENSE" target="_blank"><img alt="GitHub License" src="https://img.shields.io/github/license/Storia-AI/sage" /></a>
 </div>
 <br />
 <figure>
 <!-- The <kbd> and <sub> tags are work-arounds for styling, since GitHub doesn't take into account inline styles. Note it might display awkwardly on other Markdown editors. -->
@@ -16,253 +21,19 @@
Before:
 </figure>
 </div>
 
-## Installation
-
-<details open>
-<summary><strong>Using pipx (for regular users)</strong></summary>
-Make sure pipx is installed on your system (see <a href="https://pipx.pypa.io/stable/installation/">instructions</a>), then run:
-
-```
-pipx install git+https://github.com/Storia-AI/sage.git@main
-```
-
-</details>
-
-<details>
-<summary><strong>Using venv and pip (for contributors)</strong></summary>
-Alternatively, you can manually create a virtual environment and install Code Sage via pip:
-
-```
-python -m venv sage-venv
-source sage-venv/bin/activate
-git clone https://github.com/Storia-AI/sage.git
-cd sage
-pip install -e .
-```
-
-</details>
-
-## Prerequisites
-
-`sage` performs two steps:
-
-1. Indexes your codebase (requiring an embedder and a vector store)
-2. Enables chatting via LLM + RAG (requiring access to an LLM)
-
-<details open>
-<summary><strong>:computer: Running locally (lower quality)</strong></summary>
-
-1. To index the codebase locally, we use the open-source project <a href="https://github.com/marqo-ai/marqo">Marqo</a>, which is both an embedder and a vector store. To bring up a Marqo instance:
-
-```
-docker rm -f marqo
-docker pull marqoai/marqo:latest
-docker run --name marqo -it -p 8882:8882 marqoai/marqo:latest
-```
-
-This will open a persistent Marqo console window. It should take around 2-3 minutes on a fresh install.
-
-2. To chat with an LLM locally, we use <a href="https://github.com/ollama/ollama">Ollama</a>:
-
-- Head over to [ollama.com](https://ollama.com) to download the appropriate binary for your machine.
-- Open a new terminal window.
-- Pull the desired model, e.g. `ollama pull llama3.1`.
-
-</details>
-
-<details>
-<summary><strong>:cloud: Using external providers (higher quality)</strong></summary>
-
-1. For embeddings, we support <a href="https://platform.openai.com/docs/guides/embeddings">OpenAI</a> and <a href="https://docs.voyageai.com/docs/embeddings">Voyage</a>. According to [our experiments](benchmarks/retrieval/README.md), OpenAI is better quality. Their batch API is also faster, with more generous rate limits. Export the API key of the desired provider:
-
-```
-export OPENAI_API_KEY=... # or
-export VOYAGE_API_KEY=...
-```
-
-2. We use <a href="https://www.pinecone.io/">Pinecone</a> for the vector store, so you will need an API key:
-
-```
-export PINECONE_API_KEY=...
-```
-If you want to reuse an existing Pinecone index, specify it. Otherwise we'll create a new one called `sage`.
-```
-export PINECONE_INDEX_NAME=...
-```
-
-3. For reranking, we support <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a>, <a href="https://docs.voyageai.com/docs/reranker">Voyage</a>, <a href="https://cohere.com/rerank">Cohere</a>, and <a href="https://jina.ai/reranker/">Jina</a>.
-- According to [our experiments](benchmark/retrieval/README.md), NVIDIA performs best. To get an API key, follow [these instructions](https://docs.nvidia.com/nim/large-language-models/latest/getting-started.html#generate-an-api-key). Note that NVIDIA's API keys are model-specific. We recommend using `nvidia/nv-rerankqa-mistral-4b-v3`.
-- Export the API key of the desired provider:
-```
-export NVIDIA_API_KEY=... # or
-export VOYAGE_API_KEY=... # or
-export COHERE_API_KEY=... # or
-export JINA_API_KEY=...
-```
-
-4. For chatting with an LLM, we support OpenAI and Anthropic. For the latter, set an additional API key:
-```
-export ANTHROPIC_API_KEY=...
-```
-
-For easier configuration, adapt the entries within the sample `.sage-env` (change the API key names based on your desired setup) and run:
-```
-source .sage-env
-```
-</details>
-
-### Optional
-If you are planning on indexing GitHub issues in addition to the codebase, you will need a GitHub token:
-
-    export GITHUB_TOKEN=...
-
-## Running it
-
-1. Select your desired repository:
-```
-export GITHUB_REPO=huggingface/transformers
-```
-
-2. Index the repository. This might take a few minutes, depending on its size.
-```
-sage-index $GITHUB_REPO
-```
-To use external providers instead of running locally, set `--mode=remote`.
-
-3. Chat with the repository, once it's indexed:
-```
-sage-chat $GITHUB_REPO
-```
-To use external providers instead of running locally, set `--mode=remote`.
-</details>
-
-### Notes:
-- To get a public URL for your chat app, set `--share=true`.
-- You can overwrite the default settings (e.g. desired embedding model or LLM) via command line flags. Run `sage-index --help` or `sage-chat --help` for a full list.
-
-## Additional features
-
-<details>
-<summary><strong>:lock: Working with private repositories</strong></summary>
-
-To index and chat with a private repository, simply set the `GITHUB_TOKEN` environment variable. To obtain this token, go to github.com > click on your profile icon > Settings > Developer settings > Personal access tokens. You can either make a fine-grained token for the desired repository, or a classic token.
-
-```
-export GITHUB_TOKEN=...
-```
-
-</details>
-
-<details>
-<summary><strong>:hammer_and_wrench: Control which files get indexed</strong></summary>
-
-You can specify an inclusion or exclusion file in the following format:
-```
-# This is a comment
-ext:.my-ext-1
-ext:.my-ext-2
-ext:.my-ext-3
-dir:my-dir-1
-dir:my-dir-2
-dir:my-dir-3
-file:my-file-1.md
-file:my-file-2.py
-file:my-file-3.cpp
-```
-where:
-- `ext` specifies a file extension
-- `dir` specifies a directory. This is not a full path. For instance, if you specify `dir:tests` in an exclusion file, then a file like `/path/to/my/tests/file.py` will be ignored.
-- `file` specifies a file name. This is also not a full path. For instance, if you specify `file:__init__.py`, then a file like `/path/to/my/__init__.py` will be ignored.
-
-To specify an inclusion file (i.e. only index the specified files):
-```
-sage-index $GITHUB_REPO --include=/path/to/inclusion/file
-```
-
-To specify an exclusion file (i.e. index all files, except for the ones specified):
-```
-sage-index $GITHUB_REPO --exclude=/path/to/exclusion/file
-```
-By default, we use the exclusion file [sample-exclude.txt](sage/sample-exclude.txt).
-
-</details>
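The `ext:` / `dir:` / `file:` format above is straightforward to parse. As an independent illustration only (sage's own parser is `GitHubRepoManager._parse_filter_file`; the helper names below are hypothetical and not the project's actual implementation):

```python
import os
from typing import Dict, Iterable, List


def parse_filter_lines(lines: Iterable[str]) -> Dict[str, List[str]]:
    """Parse lines like 'ext:.py', 'dir:tests', 'file:__init__.py'; '#' starts a comment."""
    filters: Dict[str, List[str]] = {"ext": [], "dir": [], "file": []}
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition(":")
        if key in filters and value:
            filters[key].append(value)
    return filters


def matches(filters: Dict[str, List[str]], file_path: str) -> bool:
    """True if file_path matches any filter entry: extension, any path component, or basename."""
    parts = file_path.split(os.sep)
    return (
        os.path.splitext(file_path)[1] in filters["ext"]
        or any(d in parts for d in filters["dir"])
        or os.path.basename(file_path) in filters["file"]
    )


rules = parse_filter_lines(["# comment", "ext:.py", "dir:tests", "file:__init__.py"])
print(matches(rules, "/path/to/my/tests/file.rs"))  # True  ('tests' is a path component)
print(matches(rules, "src/main.py"))                # True  (extension .py)
print(matches(rules, "docs/guide.md"))              # False
```

Used as an exclusion list, a `True` match means the file is skipped; used as an inclusion list, it means the file is the only kind that gets indexed.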
-
-<details>
-<summary><strong>:bug: Index open GitHub issues</strong></summary>
-
-You will need a GitHub token first:
-
-```
-export GITHUB_TOKEN=...
-```
-
-To index GitHub issues without comments:
-```
-sage-index $GITHUB_REPO --index-issues
-```
-
-To index GitHub issues with comments:
-```
-sage-index $GITHUB_REPO --index-issues --index-issue-comments
-```
-
-To index GitHub issues, but not the codebase:
-```
-sage-index $GITHUB_REPO --index-issues --no-index-repo
-```
-
-</details>
-
-<details>
-<summary><strong>:books: Experiment with retrieval strategies</strong></summary>
-
-Retrieving the right files from the vector database is arguably the quality bottleneck of the system. We are actively experimenting with various retrieval strategies and documenting our findings [here](benchmark/retrieval/README.md).
-
-Currently, we support the following types of retrieval:
-- **Vanilla RAG** from a vector database (nearest-neighbor search between dense embeddings). This is the default.
-- **Hybrid RAG** that combines dense retrieval (embeddings-based) with sparse retrieval (BM25). Use `--retrieval-alpha` to weigh the two strategies.
-  - A value of 1 means dense-only retrieval and 0 means BM25-only retrieval.
-  - Note this is not available when running locally, only when using Pinecone as a vector store.
-  - Contrary to [Anthropic's findings](https://www.anthropic.com/news/contextual-retrieval), we find that BM25 actually damages performance *on codebases*, because it gives an undeserved advantage to Markdown files.
-- **Multi-query retrieval** performs multiple query rewrites, makes a separate retrieval call for each, and takes the union of the retrieved documents. You can activate it by passing `--multi-query-retrieval`. This can be combined with both vanilla and hybrid RAG.
-  - We find that [on our benchmark](benchmark/retrieval/README.md) this only marginally improves retrieval quality (from 0.44 to 0.46 R-precision) while being significantly slower and more expensive due to the LLM calls. But your mileage may vary.
-- **LLM-only retrieval** completely circumvents indexing the codebase. We simply enumerate all file paths and pass them to an LLM together with the user query, asking which files are likely to be relevant solely based on their filenames. You can activate it by passing `--llm-retriever`.
-  - We find that [on our benchmark](benchmark/retrieval/README.md) the performance is comparable with vector database solutions (R-precision is 0.44 for both). This is quite remarkable, since we saved so much effort by not indexing the codebase. However, we are reluctant to claim that these findings generalize, for the following reasons:
-    - Our (artificial) dataset occasionally contains explicit path names in the query, making it trivial for the LLM. Sample query: *"Alice is managing a series of machine learning experiments. Please explain in detail how `main` in `examples/pytorch/image-pretraining/run_mim.py` allows her to organize the outputs of each experiment in separate directories."*
-    - Our benchmark focuses on the Transformers library, which is well-maintained and whose file paths are often meaningful. This might not be the case for all codebases.
-
-</details>
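The `--retrieval-alpha` weighting and the multi-query union described above can be pictured with two tiny helpers (a simplified sketch only; Pinecone's hybrid search actually applies alpha by scaling the dense and sparse query vectors before the lookup):

```python
def hybrid_score(dense_score: float, bm25_score: float, alpha: float) -> float:
    """Convex combination of the two rankers: alpha=1.0 -> dense-only, alpha=0.0 -> BM25-only."""
    return alpha * dense_score + (1.0 - alpha) * bm25_score


def multi_query_union(result_lists):
    """Multi-query retrieval: union of per-rewrite result lists, preserving first-seen order."""
    seen, merged = set(), []
    for results in result_lists:
        for doc in results:
            if doc not in seen:
                seen.add(doc)
                merged.append(doc)
    return merged


# With alpha=0.7, dense similarity dominates but BM25 still contributes.
print(hybrid_score(dense_score=0.9, bm25_score=0.2, alpha=0.7))  # ~0.69
print(multi_query_union([["a", "b"], ["b", "c"]]))  # ['a', 'b', 'c']
```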
-
-# Why chat with a codebase?
-
-Sometimes you just want to learn how a codebase works and how to integrate it, without spending hours sifting through the code itself.
-
-`sage` is like an open-source GitHub Copilot with the most up-to-date information about your repo.
-
-Features:
-
-- **Dead-simple set-up.** Run *two scripts* and you have a functional chat interface for your code. That's really it.
-- **Heavily documented answers.** Every response shows where in the code the context for the answer was pulled from. Let's build trust in the AI.
-- **Runs locally or on the cloud.**
-- **Plug-and-play.** Want to improve the algorithms powering the code understanding/generation? We've made every component of the pipeline easily swappable. Google-grade engineering standards allow you to customize to your heart's content.
 
 # Want your repository hosted?
 
-We're working to make all code on the internet searchable and understandable for devs. You can check out
 
 If you're the maintainer of an OSS repo and would like a dedicated page on Code Sage (e.g. `sage.storia.ai/your-repo`), then send us a message at [founders@storia.ai](mailto:founders@storia.ai). We'll do it for free!
After:
 <a href="https://github.com/Storia-AI/sage/stargazers" target="_blank"><img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/Storia-AI/sage?logo=github&link=https%3A%2F%2Fgithub.com%2FStoria-AI%2Fsage%2Fstargazers"></a>
 <a href="https://github.com/Storia-AI/sage/blob/main/LICENSE" target="_blank"><img alt="GitHub License" src="https://img.shields.io/github/license/Storia-AI/sage" /></a>
 </div>
+<div>
+<a href="https://sage-docs.storia.ai">Documentation</a>
+<span>·</span>
+<a href="https://sage.storia.ai">Hosted app</a>
+</div>
 <br />
 <figure>
 <!-- The <kbd> and <sub> tags are work-arounds for styling, since GitHub doesn't take into account inline styles. Note it might display awkwardly on other Markdown editors. -->
 </figure>
 </div>
 
+***
 
+**Sage** is like an open-source GitHub Copilot that helps you learn how a codebase works and how to integrate it into your project without spending hours sifting through the code.
 
+# Main features
+- **Dead-simple setup**. Follow our [quickstart guide](https://sage-docs.storia.ai/quickstart) to get started.
+- **Runs locally or on the cloud**. When privacy is your priority, you can run the entire pipeline locally using [Ollama](https://ollama.com) for LLMs and [Marqo](https://github.com/marqo-ai/marqo) as a vector store. When optimizing for quality, you can use third-party LLM providers like OpenAI and Anthropic.
+- **Wide range of built-in retrieval mechanisms**. We support both lightweight retrieval strategies (with nothing more than an LLM API key required) and more traditional RAG (which requires indexing the codebase). There are many knobs you can tune for retrieval to work well on your codebase.
+- **Well-documented experiments**. We profile various strategies (for embeddings, retrieval, etc.) on our own benchmark and thoroughly [document the results](benchmarks/retrieval/README.md).
 
 # Want your repository hosted?
 
+We're working to make all code on the internet searchable and understandable for devs. You can check out our [hosted app](https://sage.storia.ai). We pre-indexed a slew of OSS repos, and you can index new ones by simply pasting a GitHub URL.
 
 If you're the maintainer of an OSS repo and would like a dedicated page on Code Sage (e.g. `sage.storia.ai/your-repo`), then send us a message at [founders@storia.ai](mailto:founders@storia.ai). We'll do it for free!
sage/config.py
CHANGED

@@ -58,8 +58,8 @@ def add_config_args(parser: ArgumentParser):
     parser.add(
         "--mode",
         choices=["local", "remote"],
-        default="local",
-        help="Whether to use local-only resources or call third-party providers.",
+        default="remote",
+        help="Whether to use local-only resources or call third-party providers (remote).",
     )
     parser.add(
         "--config",
sage/configs/remote.yaml
CHANGED

@@ -1,17 +1,15 @@
+llm-retriever: true
+llm-provider: anthropic
+reranker-provider: anthropic
+
+# The settings below (embeddings and vector store) are only relevant when setting --no-llm-retriever
+
 # Embeddings
 embedding-provider: openai
 embedding-model: text-embedding-3-small
 tokens-per-chunk: 800
 chunks-per-batch: 2000
-
 # Vector store
 vector-store-provider: pinecone
 pinecone-index-name: sage
 hybrid-retrieval: true
-
-# LLM
-llm-provider: openai
-llm-model: gpt-4o
-
-# Reranking
-reranker-provider: nvidia
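The dashed keys in `remote.yaml` mirror sage's command-line flags; the `parser.add(...)` calls in `sage/config.py` suggest a configargparse-style parser where config-file values override built-in defaults and explicit flags override both. A minimal stdlib-only sketch of that precedence (plain `argparse`, with a dict standing in for the parsed YAML; values are illustrative):

```python
import argparse

# Hypothetical stand-in for the parsed remote.yaml; the real project reads the
# YAML through its config-file-aware parser rather than a plain dict.
yaml_config = {
    "llm_provider": "anthropic",
    "reranker_provider": "anthropic",
}

parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["local", "remote"], default="remote")
parser.add_argument("--llm-provider", default="openai")
parser.add_argument("--reranker-provider", default="nvidia")

# Config-file values override the built-in defaults...
parser.set_defaults(**yaml_config)
# ...but an explicit CLI flag overrides the config file.
args = parser.parse_args(["--reranker-provider", "cohere"])
print(args.llm_provider, args.reranker_provider)  # anthropic cohere
```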
sage/reranker.py
CHANGED

@@ -1,14 +1,21 @@
+import logging
 import os
 from enum import Enum
-from typing import Optional
+from typing import List, Optional
 
 from langchain.retrievers.document_compressors import CrossEncoderReranker
 from langchain_cohere import CohereRerank
 from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 from langchain_community.document_compressors import JinaRerank
-from langchain_core.
+from langchain_core.callbacks.manager import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import PromptTemplate
 from langchain_nvidia_ai_endpoints import NVIDIARerank
 from langchain_voyageai import VoyageAIRerank
+from pydantic import ConfigDict, Field
+
+from sage.llm import build_llm_via_langchain
 
 
 class RerankerProvider(Enum):

@@ -18,6 +25,58 @@ class RerankerProvider(Enum):
     NVIDIA = "nvidia"
     JINA = "jina"
     VOYAGE = "voyage"
+    # Anthropic doesn't provide an explicit reranker; we simply prompt the LLM with the user query and the content of
+    # the top k documents.
+    ANTHROPIC = "anthropic"
+
+
+class LLMReranker(BaseDocumentCompressor):
+    """Reranker that passes the user query and top N documents to a language model to order them.
+
+    Note that Langchain's RankLLMRerank does not support LLMs from Anthropic.
+    https://python.langchain.com/api_reference/community/document_compressors/langchain_community.document_compressors.rankllm_rerank.RankLLMRerank.html
+    Also, they rely on https://github.com/castorini/rank_llm, which doesn't run on Apple Silicon (M1/M2 chips).
+    """
+
+    llm: BaseLanguageModel = Field(...)
+    top_k: int = Field(...)
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        extra="forbid",
+    )
+
+    @property
+    def prompt(self):
+        return PromptTemplate.from_template(
+            "Given the following query: '{query}'\n\n"
+            "And these documents:\n\n{documents}\n\n"
+            "Rank the documents based on their relevance to the query. "
+            "Return only the document numbers in order of relevance, separated by commas. For example: 2,5,1,3,4. "
+            "Return absolutely nothing else."
+        )
+
+    def compress_documents(
+        self,
+        documents: List[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> List[Document]:
+        if len(documents) <= self.top_k:
+            return documents
+
+        doc_texts = [f"Document {i+1}:\n{doc.page_content}\n" for i, doc in enumerate(documents)]
+        docs_str = "\n".join(doc_texts)
+
+        llm_input = self.prompt.format(query=query, documents=docs_str)
+        result = self.llm.predict(llm_input)
+
+        try:
+            ranked_indices = [int(idx) - 1 for idx in result.strip().split(",")][: self.top_k]
+            return [documents[i] for i in ranked_indices]
+        except ValueError:
+            logging.warning("Failed to parse reranker output. Returning original order. LLM responded with: %s", result)
+            return documents[: self.top_k]
 
 
 def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:

@@ -46,4 +105,7 @@
             raise ValueError("Please set the VOYAGE_API_KEY environment variable")
         model = model or "rerank-1"
         return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)
+    if provider == RerankerProvider.ANTHROPIC.value:
+        llm = build_llm_via_langchain("anthropic", model)
+        return LLMReranker(llm=llm, top_k=1)
     raise ValueError(f"Invalid reranker provider: {provider}")
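The heart of the new `LLMReranker` is its parse step: the LLM answers with a comma-separated 1-indexed ranking like `2,5,1,3,4`, which gets mapped back onto the document list, falling back to the original order if the output is unparseable. That logic, extracted into a standalone function for illustration (hypothetical helper name, no langchain dependency):

```python
import logging
from typing import List


def rerank_by_llm_output(documents: List[str], llm_output: str, top_k: int) -> List[str]:
    """Reorder `documents` per a comma-separated, 1-indexed ranking like '2,5,1,3,4'.

    Mirrors LLMReranker.compress_documents: on unparseable output, fall back to
    the original order truncated to top_k.
    """
    try:
        ranked_indices = [int(idx) - 1 for idx in llm_output.strip().split(",")][:top_k]
        return [documents[i] for i in ranked_indices]
    except ValueError:
        logging.warning("Failed to parse reranker output: %s", llm_output)
        return documents[:top_k]


docs = ["a", "b", "c", "d", "e"]
print(rerank_by_llm_output(docs, "2,5,1,3,4", 3))  # ['b', 'e', 'a']
print(rerank_by_llm_output(docs, "garbage", 3))    # ['a', 'b', 'c']
```

Note that, like the committed code, this only guards against non-numeric output (`ValueError`); an out-of-range index from the LLM would still raise.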

tests/test_data_manager.py
CHANGED

@@ -14,102 +14,93 @@ Here, one can find a set of unit tests for GitHubRepoManager class from sage.data_manager
Before:
 
 """
 
 import unittest
-from unittest.mock import
 from sage.data_manager import GitHubRepoManager
 class TestGitHubRepoManager(unittest.TestCase):
-    @patch(
     def test_download_clone_success(self, mock_clone):
         """Test the download() method of GitHubRepoManager by mocking the cloning process."""
-        repo_manager = GitHubRepoManager(
-            local_dir='/tmp/test_repo'
-        )
-        mock_clone.return_value = MagicMock()
         result = repo_manager.download()
         mock_clone.assert_called_once_with(
-            '/tmp/test_repo/Storia-AI/sage',
-            depth=1,
-            single_branch=True
         )
         self.assertTrue(result)
 
-    @patch(
     def test_is_public_repository(self, mock_get):
         """Test the is_public property to check if a repository is public."""
         mock_get.return_value.status_code = 200
-        repo_manager = GitHubRepoManager(repo_id=
         self.assertTrue(repo_manager.is_public)
-        mock_get.assert_called_once_with(
 
-    @patch(
     def test_is_private_repository(self, mock_get):
         """Test the is_public property to check if a repository is private."""
-        mock_get.return_value.status_code = 404
-        repo_manager = GitHubRepoManager(repo_id=
         self.assertFalse(repo_manager.is_public)
-        mock_get.assert_called_once_with(
 
-    @patch(
     def test_default_branch(self, mock_get):
         """Test the default_branch property to fetch the default branch of the repository."""
         mock_get.return_value.status_code = 200
-        mock_get.return_value.json.return_value = {
-        repo_manager = GitHubRepoManager(repo_id=
-        self.assertEqual(repo_manager.default_branch,
         mock_get.assert_called_once_with(
-            headers={'Accept': 'application/vnd.github.v3+json'}
         )
 
-    @patch(
     def test_parse_filter_file(self, mock_file):
         """Test the _parse_filter_file method for correct parsing of inclusion/exclusion files."""
-        repo_manager = GitHubRepoManager(repo_id=
-        expected = {
-            'file': ['test.py'],
-            'dir': ['test_dir']
-        }
-        result = repo_manager._parse_filter_file('dummy_path')
         self.assertEqual(result, expected)
 
-    @patch(
-    @patch(
-    @patch(
     def test_walk_included_files(self, mock_open, mock_remove, mock_exists):
         """Test the walk method to ensure it only includes specified files."""
         mock_exists.return_value = True
-        repo_manager = GitHubRepoManager(
         included_files = list(repo_manager.walk())
         print("Included files:", included_files)
-        self.assertTrue(any(file[1][
 
     def test_read_file(self):
         """Test the read_file method to read the content of a file."""
-        mock_file_path =
-        with patch(
-        repo_manager = GitHubRepoManager(repo_id=
-        content = repo_manager.read_file(
-        self.assertEqual(content,
 
-    @patch(
     def test_create_log_directories(self, mock_makedirs):
         """Test that log directories are created."""
-        repo_manager = GitHubRepoManager(
-            repo_id="Storia-AI/sage",
-            local_dir="/tmp/test_repo"
-        )
 
         with self.assertRaises(AttributeError):
             repo_manager.create_log_directories()
 unittest.main()

After:
 
 """
 
+import os
 import unittest
+from unittest.mock import MagicMock, patch
+
 from sage.data_manager import GitHubRepoManager
+
+
 class TestGitHubRepoManager(unittest.TestCase):
+    @patch("git.Repo.clone_from")
     def test_download_clone_success(self, mock_clone):
         """Test the download() method of GitHubRepoManager by mocking the cloning process."""
+        repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage", local_dir="/tmp/test_repo")
+        mock_clone.return_value = MagicMock()
         result = repo_manager.download()
         mock_clone.assert_called_once_with(
+            "https://github.com/Storia-AI/sage.git", "/tmp/test_repo/Storia-AI/sage", depth=1, single_branch=True
         )
         self.assertTrue(result)
 
+    @patch("sage.data_manager.requests.get")
     def test_is_public_repository(self, mock_get):
         """Test the is_public property to check if a repository is public."""
         mock_get.return_value.status_code = 200
+        repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage")
         self.assertTrue(repo_manager.is_public)
+        mock_get.assert_called_once_with("https://api.github.com/repos/Storia-AI/sage", timeout=10)
 
+    @patch("sage.data_manager.requests.get")
     def test_is_private_repository(self, mock_get):
         """Test the is_public property to check if a repository is private."""
+        mock_get.return_value.status_code = 404
+        repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage")
         self.assertFalse(repo_manager.is_public)
+        mock_get.assert_called_once_with("https://api.github.com/repos/Storia-AI/sage", timeout=10)
 
+    @patch("sage.data_manager.requests.get")
     def test_default_branch(self, mock_get):
         """Test the default_branch property to fetch the default branch of the repository."""
         mock_get.return_value.status_code = 200
+        mock_get.return_value.json.return_value = {"default_branch": "main"}
+        repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage")
+        self.assertEqual(repo_manager.default_branch, "main")
         mock_get.assert_called_once_with(
+            "https://api.github.com/repos/Storia-AI/sage", headers={"Accept": "application/vnd.github.v3+json"}
         )
 
+    @patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="ext:.py\nfile:test.py\ndir:test_dir\n")
     def test_parse_filter_file(self, mock_file):
         """Test the _parse_filter_file method for correct parsing of inclusion/exclusion files."""
+        repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage", inclusion_file="dummy_path")
+        expected = {"ext": [".py"], "file": ["test.py"], "dir": ["test_dir"]}
+        result = repo_manager._parse_filter_file("dummy_path")
         self.assertEqual(result, expected)
 
+    @patch("os.path.exists")
+    @patch("os.remove")
+    @patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="dummy content")
     def test_walk_included_files(self, mock_open, mock_remove, mock_exists):
         """Test the walk method to ensure it only includes specified files."""
         mock_exists.return_value = True
+        repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage", local_dir="/tmp/test_repo")
+        with patch(
+            "os.walk",
+            return_value=[
+                ("/tmp/test_repo", ("subdir",), ("included_file.py", "excluded_file.txt")),
+            ],
+        ):
             included_files = list(repo_manager.walk())
             print("Included files:", included_files)
+            self.assertTrue(any(file[1]["file_path"] == "included_file.py" for file in included_files))
+
     def test_read_file(self):
         """Test the read_file method to read the content of a file."""
+        mock_file_path = "/tmp/test_repo/test_file.txt"
+        with patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="Hello, World!"):
|
| 92 |
+
repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage", local_dir="/tmp/test_repo")
|
| 93 |
+
content = repo_manager.read_file("test_file.txt")
|
| 94 |
+
self.assertEqual(content, "Hello, World!")
|
| 95 |
|
| 96 |
+
@patch("os.makedirs")
|
| 97 |
def test_create_log_directories(self, mock_makedirs):
|
| 98 |
"""Test that log directories are created."""
|
| 99 |
+
repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage", local_dir="/tmp/test_repo")
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
with self.assertRaises(AttributeError):
|
| 102 |
repo_manager.create_log_directories()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
if __name__ == "main":
|
| 106 |
unittest.main()
|
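A detail worth noting in `test_walk_included_files` above: stacked `@patch` decorators are applied bottom-up, so the decorator closest to the function supplies the *first* mock argument. A minimal, self-contained sketch of that ordering (the `check` function and paths are illustrative, not part of sage):

```python
import os
from unittest.mock import patch

# Stacked @patch decorators are applied bottom-up: the decorator closest to
# the function provides the FIRST mock argument (os.path.exists here), the
# outermost one the last (os.remove).
@patch("os.remove")
@patch("os.path.exists")
def check(mock_exists, mock_remove):
    mock_exists.return_value = True
    assert os.path.exists("/definitely/not/a/real/path")  # answered by the mock
    os.remove("/definitely/not/a/real/path")              # recorded, nothing is deleted
    mock_remove.assert_called_once_with("/definitely/not/a/real/path")
    return True

print(check())  # True
```

Getting this ordering wrong silently hands each test the wrong mock, which is why the parameter order in `test_walk_included_files` mirrors the decorators in reverse.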
tests/test_github.py
CHANGED
|
@@ -9,13 +9,22 @@ pip install pytest

```python
pip install pytest-mock
"""

from unittest.mock import MagicMock, patch

import pytest
import requests

from sage.github import (
    GitHubIssue,
    GitHubIssueComment,
    GitHubIssuesChunker,
    GitHubIssuesManager,
    IssueChunk,
)


class TestGitHubIssuesManager:

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Fixture to create a GitHubIssuesManager instance for each test."""
```

@@ -24,33 +33,37 @@ class TestGitHubIssuesManager:

```python
    @staticmethod
    def mock_issue_response():
        """A mock response for GitHub issues."""
        return MagicMock(
            json=lambda: [
                {
                    "url": "https://api.github.com/repos/random/random-repo/issues/1",
                    "html_url": "https://github.com/random/random-repo/issues/1",
                    "title": "Found a bug",
                    "body": "I'm having a problem with this.",
                    "comments_url": "https://api.github.com/repos/random/random-repo/issues/1/comments",
                    "comments": 2,
                }
            ]
        )

    @staticmethod
    def mock_comment_response():
        """Create a mock response for GitHub issue comments."""
        return MagicMock(
            json=lambda: [
                {
                    "url": "https://api.github.com/repos/random/random-repo/issues/comments/1",
                    "html_url": "https://github.com/random/random-repo/issues/comments/1",
                    "body": "This is a comment.",
                }
            ]
        )

    @patch("github.requests.get")
    def test_download_issues(self, mock_get):
        """Test the download of issues from GitHub."""
        mock_get.side_effect = [self.mock_issue_response(), self.mock_comment_response()]

        self.github_manager.download()

        assert len(self.github_manager.issues) == 1
```

@@ -58,21 +71,27 @@ class TestGitHubIssuesManager:

```python
        assert self.github_manager.issues[0].body == "I'm having a problem with this."
        assert self.github_manager.issues[0].url == "https://api.github.com/repos/random/random-repo/issues/1"

    @patch("github.requests.get")
    def test_walk_issues(self, mock_get):
        """Test the walking through downloaded issues."""
        self.github_manager.issues = [
            GitHubIssue(url="issue_url", html_url="html_issue_url", title="Test Issue", body="Test Body", comments=[]),
            GitHubIssue(
                url="issue_url_2",
                html_url="html_issue_url_2",
                title="Another Test Issue",
                body="Another Test Body",
                comments=[],
            ),
        ]

        issues = list(self.github_manager.walk())

        assert len(issues) == 2
        assert issues[0][0].title == "Test Issue"
        assert issues[1][0].title == "Another Test Issue"

    @patch("github.requests.get")
    def test_get_page_of_issues(self, mock_get):
        """Test fetching a page of issues."""
        mock_response = MagicMock()
```

@@ -80,23 +99,25 @@ class TestGitHubIssuesManager:

```python
            {
                "url": "https://api.github.com/repos/random/random-repo/issues/1",
                "html_url": "https://github.com/random/random-repo/issues/1",
                "title": "Found a bug",
                "body": "I'm having a problem with this.",
                "comments_url": "https://api.github.com/repos/random/random-repo/issues/1/comments",
                "comments": 2,
            }
        ]
        mock_get.return_value = mock_response

        issues = self.github_manager._get_page_of_issues(
            "https://api.github.com/repos/random/random-repo/issues?page=1"
        ).json()

        assert len(issues) == 1

    @patch("github.requests.get")
    def test_get_comments(self, mock_get):
        """Test retrieving comments for an issue."""
        mock_get.return_value.json.return_value = self.mock_comment_response().json()

        comments = self.github_manager._get_comments("comments_url")
        assert len(comments) == 1
        assert comments[0].body == "This is a comment."
```

@@ -111,7 +132,7 @@ class TestGitHubIssuesManager:

```python
            comments=[
                GitHubIssueComment(url="comment_url_1", html_url="html_comment_url_1", body="First comment."),
                GitHubIssueComment(url="comment_url_2", html_url="html_comment_url_2", body="Second comment."),
            ],
        )

        chunker = GitHubIssuesChunker(max_tokens=50)
```

@@ -124,7 +145,7 @@ class TestGitHubIssuesManager:

```python
class TestGitHubIssueComment:

    def test_initialization(self):
        """Test the initialization of the GitHubIssueComment class."""
        comment = GitHubIssueComment(url="comment_url", html_url="html_comment_url", body="Sample comment")
```

@@ -139,10 +160,12 @@ class TestGitHubIssueComment:

```python
class TestGitHubIssue:

    def test_initialization(self):
        """Test the initialization of the GitHubIssue class."""
        issue = GitHubIssue(
            url="issue_url", html_url="html_issue_url", title="Test Issue", body="Test Body", comments=[]
        )
        assert issue.url == "issue_url"
        assert issue.html_url == "html_issue_url"
        assert issue.title == "Test Issue"
```

@@ -151,15 +174,19 @@ class TestGitHubIssue:

```python
    def test_pretty_property(self):
        """Test the pretty property of the GitHubIssue class."""
        issue = GitHubIssue(
            url="issue_url", html_url="html_issue_url", title="Test Issue", body="Test Body", comments=[]
        )
        assert issue.pretty == "# Issue: Test Issue\nTest Body"


class TestIssueChunk:

    def test_initialization(self):
        """Test the initialization of the IssueChunk class."""
        issue = GitHubIssue(
            url="issue_url", html_url="html_issue_url", title="Test Issue", body="Test Body", comments=[]
        )
        chunk = IssueChunk(issue=issue, start_comment=0, end_comment=1)
        assert chunk.issue == issue
        assert chunk.start_comment == 0
```

@@ -167,32 +194,38 @@ class TestIssueChunk:

```python
    def test_content_property(self):
        """Test the content property of the IssueChunk class."""
        issue = GitHubIssue(
            url="issue_url", html_url="html_issue_url", title="Test Issue", body="Test Body", comments=[]
        )
        chunk = IssueChunk(issue=issue, start_comment=0, end_comment=1)
        assert chunk.content == "# Issue: Test Issue\nTest Body\n\n"

    def test_metadata_property(self):
        """Test the metadata property of the IssueChunk class."""
        issue = GitHubIssue(
            url="issue_url", html_url="html_issue_url", title="Test Issue", body="Test Body", comments=[]
        )
        chunk = IssueChunk(issue=issue, start_comment=0, end_comment=1)
        expected_metadata = {
            "id": "html_issue_url_0_1",
            "url": "html_issue_url",
            "start_comment": 0,
            "end_comment": 1,
            "text": "# Issue: Test Issue\nTest Body\n\n",
        }
        assert chunk.metadata == expected_metadata

    def test_num_tokens_property(self):
        """Test the num_tokens property of the IssueChunk class."""
        issue = GitHubIssue(
            url="issue_url", html_url="html_issue_url", title="Test Issue", body="This is a test body.", comments=[]
        )
        chunk = IssueChunk(issue=issue, start_comment=0, end_comment=1)
        assert chunk.num_tokens == 12


class TestGitHubIssuesChunker:

    def test_initialization(self):
        """Test the initialization of the GitHubIssuesChunker class."""
        chunker = GitHubIssuesChunker(max_tokens=50)
```

@@ -205,16 +238,16 @@ class TestGitHubIssuesChunker:

```python
            html_url="html_issue_url",
            title="Test Issue",
            body="This is a long body of the issue that needs to be chunked.",
            comments=[],
        )

        chunker = GitHubIssuesChunker(max_tokens=50)
        chunks = chunker.chunk(content=issue, metadata={})

        assert len(chunks) > 0

        assert all(isinstance(chunk, IssueChunk) for chunk in chunks)


if __name__ == "__main__":
    pytest.main()
```
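The `mock_issue_response()` and `mock_comment_response()` helpers above lean on the fact that `MagicMock` accepts arbitrary keyword attributes, so `json=lambda: [...]` makes `response.json()` return a canned payload without any network call. A minimal standalone sketch of the same pattern (the `IssueClient` class and its methods are hypothetical, not part of sage):

```python
from unittest.mock import MagicMock, patch

class IssueClient:
    """Hypothetical stand-in for code that talks to the GitHub API."""

    def get(self, url):
        raise RuntimeError("no network access in tests")

    def fetch_titles(self, url):
        return [issue["title"] for issue in self.get(url).json()]

# A fake HTTP response: `json` is a plain attribute holding a lambda,
# so fake_response.json() invokes it and returns the canned payload.
fake_response = MagicMock(status_code=200, json=lambda: [{"title": "Found a bug"}])

with patch.object(IssueClient, "get", return_value=fake_response) as mock_get:
    titles = IssueClient().fetch_titles("https://api.github.com/repos/example/repo/issues")

mock_get.assert_called_once_with("https://api.github.com/repos/example/repo/issues")
print(titles)  # ['Found a bug']
```

The tests combine this with `side_effect` (as in `test_download_issues`) when successive calls must return different responses, e.g. one page of issues followed by that issue's comments.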