Sanketh Kumar committed
Commit · df22b26
1 Parent(s): e600966

chore: added pre-commit-hooks and ruff formatting for commit-hooks

Files changed:
- .gitignore +2 -1
- .pre-commit-config.yaml +22 -0
- README.md +25 -25
- examples/batch_eval.py +17 -21
- examples/generate_query.py +4 -5
- examples/lightrag_azure_openai_demo.py +1 -1
- examples/lightrag_bedrock_demo.py +4 -9
- examples/lightrag_hf_demo.py +23 -12
- examples/lightrag_ollama_demo.py +15 -10
- examples/lightrag_openai_compatible_demo.py +21 -11
- examples/lightrag_openai_demo.py +14 -8
- lightrag/__init__.py +1 -1
- lightrag/base.py +7 -4
- lightrag/lightrag.py +29 -36
- lightrag/llm.py +144 -79
- lightrag/operate.py +154 -75
- lightrag/prompt.py +4 -10
- lightrag/storage.py +7 -8
- lightrag/utils.py +24 -4
- reproduce/Step_0.py +15 -9
- reproduce/Step_1.py +5 -3
- reproduce/Step_1_openai_compatible.py +17 -12
- reproduce/Step_2.py +11 -9
- reproduce/Step_3.py +21 -8
- reproduce/Step_3_openai_compatible.py +35 -19
- requirements.txt +8 -8
.gitignore CHANGED

@@ -1,4 +1,5 @@
 __pycache__
 *.egg-info
 dickens/
-book.txt
+book.txt
+lightrag-dev/
.pre-commit-config.yaml ADDED

@@ -0,0 +1,22 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: requirements-txt-fixer
+
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.4
+    hooks:
+      - id: ruff-format
+      - id: ruff
+        args: [--fix]
+
+
+  - repo: https://github.com/mgedmin/check-manifest
+    rev: "0.49"
+    hooks:
+      - id: check-manifest
+        stages: [manual]
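The ruff hooks above are what drive most of the mechanical line changes in the files below. As a rough, assumed illustration (not part of this commit), ruff-format normalizes quoting and spacing while `ruff --fix` cleans up issues such as unused imports:

# Hypothetical before/after sketch of the style the ruff hooks enforce.
from pathlib import Path

path = Path("example.txt")  # placeholder file so the snippet runs end to end
path.write_text("hello\n")

# A line previously written as:  with open(path,'r') as f: data=f.read()
# is rewritten by ruff-format into double quotes, spaced arguments,
# and one statement per line:
with open(path, "r") as f:
    data = f.read()

print(data)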
README.md CHANGED

@@ -16,16 +16,16 @@
 <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
 <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
 </p>
-
+
 This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
 
 </div>
 
-## 🎉 News
+## 🎉 News
 - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
 - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
-- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
-- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
 
 ## Install
 

@@ -83,7 +83,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
 <details>
 <summary> Using Open AI-like APIs </summary>
 
-LightRAG also
+LightRAG also supports Open AI-like chat/embeddings APIs:
 ```python
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs

@@ -120,7 +120,7 @@ rag = LightRAG(
 
 <details>
 <summary> Using Hugging Face Models </summary>
-
+
 If you want to use Hugging Face models, you only need to set LightRAG as follows:
 ```python
 from lightrag.llm import hf_model_complete, hf_embedding

@@ -136,7 +136,7 @@ rag = LightRAG(
         embedding_dim=384,
         max_token_size=5000,
         func=lambda texts: hf_embedding(
-            texts,
+            texts,
             tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
             embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
         )

@@ -148,7 +148,7 @@ rag = LightRAG(
 <details>
 <summary> Using Ollama Models </summary>
 If you want to use Ollama models, you only need to set LightRAG as follows:
-
+
 ```python
 from lightrag.llm import ollama_model_complete, ollama_embedding
 

@@ -162,7 +162,7 @@ rag = LightRAG(
         embedding_dim=768,
         max_token_size=8192,
         func=lambda texts: ollama_embedding(
-            texts,
+            texts,
             embed_model="nomic-embed-text"
         )
     ),

@@ -187,14 +187,14 @@ with open("./newText.txt") as f:
 ```
 ## Evaluation
 ### Dataset
-The dataset used in LightRAG can be
+The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
 
 ### Generate Query
-LightRAG uses the following prompt to generate high-level queries, with the corresponding code
+LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
 
 <details>
 <summary> Prompt </summary>
-
+
 ```python
 Given the following description of a dataset:
 

@@ -219,18 +219,18 @@ Output the results in the following structure:
 ...
 ```
 </details>
-
+
 ### Batch Eval
 To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
 
 <details>
 <summary> Prompt </summary>
-
+
 ```python
 ---Role---
 You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
 ---Goal---
-You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
+You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
 
 - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
 - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?

@@ -294,7 +294,7 @@ Output your evaluation in the following JSON format:
 | **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
 | **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
 
-## Reproduce
+## Reproduce
 All the code can be found in the `./reproduce` directory.
 
 ### Step-0 Extract Unique Contexts

@@ -302,7 +302,7 @@ First, we need to extract unique contexts in the datasets.
 
 <details>
 <summary> Code </summary>
-
+
 ```python
 def extract_unique_contexts(input_directory, output_directory):
 

@@ -361,12 +361,12 @@ For the extracted contexts, we insert them into the LightRAG system.
 
 <details>
 <summary> Code </summary>
-
+
 ```python
 def insert_text(rag, file_path):
     with open(file_path, mode='r') as f:
         unique_contexts = json.load(f)
-
+
     retries = 0
     max_retries = 3
     while retries < max_retries:

@@ -384,11 +384,11 @@ def insert_text(rag, file_path):
 
 ### Step-2 Generate Queries
 
-We extract tokens from
+We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
 
 <details>
 <summary> Code </summary>
-
+
 ```python
 tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 

@@ -401,7 +401,7 @@ def get_summary(context, tot_tokens=2000):
 
     summary_tokens = start_tokens + end_tokens
     summary = tokenizer.convert_tokens_to_string(summary_tokens)
-
+
     return summary
 ```
 </details>

@@ -411,12 +411,12 @@ For the queries generated in Step-2, we will extract them and query LightRAG.
 
 <details>
 <summary> Code </summary>
-
+
 ```python
 def extract_queries(file_path):
     with open(file_path, 'r') as f:
         data = f.read()
-
+
     data = data.replace('**', '')
 
     queries = re.findall(r'- Question \d+: (.+)', data)

@@ -470,7 +470,7 @@ def extract_queries(file_path):
 
 ```python
 @article{guo2024lightrag,
-    title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
+    title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
     author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
     year={2024},
     eprint={2410.05779},
examples/batch_eval.py CHANGED

@@ -1,4 +1,3 @@
-import os
 import re
 import json
 import jsonlines

@@ -9,28 +8,28 @@ from openai import OpenAI
 def batch_eval(query_file, result1_file, result2_file, output_file_path):
     client = OpenAI()
 
-    with open(query_file,
+    with open(query_file, "r") as f:
         data = f.read()
 
-    queries = re.findall(r
+    queries = re.findall(r"- Question \d+: (.+)", data)
 
-    with open(result1_file,
+    with open(result1_file, "r") as f:
         answers1 = json.load(f)
-    answers1 = [i[
+    answers1 = [i["result"] for i in answers1]
 
-    with open(result2_file,
+    with open(result2_file, "r") as f:
         answers2 = json.load(f)
-    answers2 = [i[
+    answers2 = [i["result"] for i in answers2]
 
     requests = []
     for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
-        sys_prompt =
+        sys_prompt = """
 ---Role---
 You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
 """
 
         prompt = f"""
-You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
+You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
 
 - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
 - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?

@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
     }}
     """
 
-
         request_data = {
             "custom_id": f"request-{i+1}",
             "method": "POST",

@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
                 "model": "gpt-4o-mini",
                 "messages": [
                     {"role": "system", "content": sys_prompt},
-                    {"role": "user", "content": prompt}
+                    {"role": "user", "content": prompt},
                 ],
-            }
+            },
         }
-
+
         requests.append(request_data)
 
-    with jsonlines.open(output_file_path, mode=
+    with jsonlines.open(output_file_path, mode="w") as writer:
         for request in requests:
             writer.write(request)
 
     print(f"Batch API requests written to {output_file_path}")
 
     batch_input_file = client.files.create(
-        file=open(output_file_path, "rb"),
-        purpose="batch"
+        file=open(output_file_path, "rb"), purpose="batch"
     )
     batch_input_file_id = batch_input_file.id
 

@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
         input_file_id=batch_input_file_id,
         endpoint="/v1/chat/completions",
         completion_window="24h",
-        metadata={
-            "description": "nightly eval job"
-        }
+        metadata={"description": "nightly eval job"},
     )
 
-    print(f
+    print(f"Batch {batch.id} has been created.")
+
 
 if __name__ == "__main__":
-    batch_eval()
+    batch_eval()
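As committed, the `__main__` block still calls `batch_eval()` with no arguments even though the function requires four file paths, so the script is presumably meant to be edited before use. A minimal, assumed invocation (the file names are placeholders, not taken from this commit), assuming `batch_eval` from the script above is importable:

# Hypothetical driver for examples/batch_eval.py; all four paths are placeholders.
if __name__ == "__main__":
    batch_eval(
        query_file="./queries.txt",          # questions, one "- Question N: ..." per line
        result1_file="./result_rag_a.json",  # list of {"result": ...} answers from system A
        result2_file="./result_rag_b.json",  # list of {"result": ...} answers from system B
        output_file_path="./batch_requests.jsonl",  # JSONL written for the OpenAI Batch API
    )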
examples/generate_query.py CHANGED

@@ -1,9 +1,8 @@
-import os
-
 from openai import OpenAI
 
 # os.environ["OPENAI_API_KEY"] = ""
 
+
 def openai_complete_if_cache(
     model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
 ) -> str:

@@ -47,10 +46,10 @@ if __name__ == "__main__":
     ...
     """
 
-    result = openai_complete_if_cache(model=
+    result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)
 
-    file_path =
+    file_path = "./queries.txt"
     with open(file_path, "w") as file:
         file.write(result)
 
-    print(f"Queries written to {file_path}")
+    print(f"Queries written to {file_path}")
examples/lightrag_azure_openai_demo.py CHANGED

@@ -122,4 +122,4 @@ print("\nResult (Global):")
 print(rag.query(query_text, param=QueryParam(mode="global")))
 
 print("\nResult (Hybrid):")
-print(rag.query(query_text, param=QueryParam(mode="hybrid")))
+print(rag.query(query_text, param=QueryParam(mode="hybrid")))
examples/lightrag_bedrock_demo.py CHANGED

@@ -20,13 +20,11 @@ rag = LightRAG(
     llm_model_func=bedrock_complete,
     llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
     embedding_func=EmbeddingFunc(
-        embedding_dim=1024,
-
-        func=bedrock_embedding
-    )
+        embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
+    ),
 )
 
-with open("./book.txt",
+with open("./book.txt", "r", encoding="utf-8") as f:
     rag.insert(f.read())
 
 for mode in ["naive", "local", "global", "hybrid"]:

@@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]:
     print(f"| {mode.capitalize()} |")
     print("+-" + "-" * len(mode) + "-+\n")
     print(
-        rag.query(
-            "What are the top themes in this story?",
-            param=QueryParam(mode=mode)
-        )
+        rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
     )
examples/lightrag_hf_demo.py CHANGED

@@ -1,10 +1,9 @@
 import os
-import sys
 
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import hf_model_complete, hf_embedding
 from lightrag.utils import EmbeddingFunc
-from transformers import AutoModel,AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
 
 WORKING_DIR = "./dickens"
 

@@ -13,16 +12,20 @@ if not os.path.exists(WORKING_DIR):
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
-    llm_model_func=hf_model_complete,
-    llm_model_name=
+    llm_model_func=hf_model_complete,
+    llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
         func=lambda texts: hf_embedding(
-            texts,
-            tokenizer=AutoTokenizer.from_pretrained(
-
-
+            texts,
+            tokenizer=AutoTokenizer.from_pretrained(
+                "sentence-transformers/all-MiniLM-L6-v2"
+            ),
+            embed_model=AutoModel.from_pretrained(
+                "sentence-transformers/all-MiniLM-L6-v2"
+            ),
+        ),
     ),
 )
 

@@ -31,13 +34,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())
 
 # Perform naive search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
 
 # Perform local search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
 
 # Perform global search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
 
 # Perform hybrid search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
examples/lightrag_ollama_demo.py CHANGED

@@ -11,15 +11,12 @@ if not os.path.exists(WORKING_DIR):
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
-    llm_model_func=ollama_model_complete,
-    llm_model_name=
+    llm_model_func=ollama_model_complete,
+    llm_model_name="your_model_name",
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
-            texts,
-            embed_model="nomic-embed-text"
-        )
+        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
     ),
 )
 

@@ -28,13 +25,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())
 
 # Perform naive search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
 
 # Perform local search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
 
 # Perform global search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
 
 # Perform hybrid search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
examples/lightrag_openai_compatible_demo.py CHANGED

@@ -6,10 +6,11 @@ from lightrag.utils import EmbeddingFunc
 import numpy as np
 
 WORKING_DIR = "./dickens"
-
+
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:

@@ -20,17 +21,19 @@ async def llm_model_func(
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )
 
+
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
 
+
 # function test
 async def test_funcs():
     result = await llm_model_func("How are you?")

@@ -39,6 +42,7 @@ async def test_funcs():
     result = await embedding_func(["How are you?"])
     print("embedding_func: ", result)
 
+
 asyncio.run(test_funcs())
 
 

@@ -46,10 +50,8 @@ rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=llm_model_func,
     embedding_func=EmbeddingFunc(
-        embedding_dim=4096,
-
-        func=embedding_func
-    )
+        embedding_dim=4096, max_token_size=8192, func=embedding_func
+    ),
 )
 
 

@@ -57,13 +59,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())
 
 # Perform naive search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
 
 # Perform local search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
 
 # Perform global search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
 
 # Perform hybrid search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
examples/lightrag_openai_demo.py CHANGED

@@ -1,9 +1,7 @@
 import os
-import sys
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete
-from transformers import AutoModel,AutoTokenizer
+from lightrag.llm import gpt_4o_mini_complete
 
 WORKING_DIR = "./dickens"
 

@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
-    llm_model_func=gpt_4o_mini_complete
+    llm_model_func=gpt_4o_mini_complete,
     # llm_model_func=gpt_4o_complete
 )
 

@@ -21,13 +19,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())
 
 # Perform naive search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
 
 # Perform local search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
 
 # Perform global search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
 
 # Perform hybrid search
-print(
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
lightrag/__init__.py CHANGED

@@ -1,4 +1,4 @@
-from .lightrag import LightRAG, QueryParam
+from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
 __version__ = "0.0.6"
 __author__ = "Zirui Guo"
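The `LightRAG as LightRAG` form is the explicit re-export idiom: ruff (like other linters) treats `import X as X` as an intentional public re-export rather than an unused import, so the F401 autofix added by the `ruff` hook does not strip it from `__init__.py`. A small, runnable illustration of the same pattern using the standard library:

# Stand-in for a package __init__.py; collections is used only so this runs as-is.
from collections import OrderedDict as OrderedDict  # explicit re-export, not flagged as unused

__all__ = ["OrderedDict"]  # optional: makes the re-exported surface explicit

print(OrderedDict(a=1))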
lightrag/base.py CHANGED

@@ -12,15 +12,16 @@ TextChunkSchema = TypedDict(
 
 T = TypeVar("T")
 
+
 @dataclass
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
     response_type: str = "Multiple Paragraphs"
     top_k: int = 60
-    max_token_for_text_unit: int = 4000
+    max_token_for_text_unit: int = 4000
     max_token_for_global_context: int = 4000
-    max_token_for_local_context: int = 4000
+    max_token_for_local_context: int = 4000
 
 
 @dataclass

@@ -36,6 +37,7 @@ class StorageNameSpace:
         """commit the storage operations after querying"""
         pass
 
+
 @dataclass
 class BaseVectorStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc

@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
     """
     raise NotImplementedError
 
+
 @dataclass
 class BaseKVStorage(Generic[T], StorageNameSpace):
     async def all_keys(self) -> list[str]:

@@ -72,7 +75,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
 
     async def drop(self):
         raise NotImplementedError
-
+
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace):

@@ -113,4 +116,4 @@ class BaseGraphStorage(StorageNameSpace):
         raise NotImplementedError
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
-        raise NotImplementedError("Node embedding is not used in lightrag.")
+        raise NotImplementedError("Node embedding is not used in lightrag.")
lightrag/lightrag.py CHANGED

@@ -3,10 +3,12 @@ import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
-from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
+from typing import Type, cast
 
-from .llm import
+from .llm import (
+    gpt_4o_mini_complete,
+    openai_embedding,
+)
 from .operate import (
     chunking_by_token_size,
     extract_entities,

@@ -37,6 +39,7 @@ from .base import (
     QueryParam,
 )
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_running_loop()

@@ -69,7 +72,6 @@ class LightRAG:
             "dimensions": 1536,
             "num_walks": 10,
             "walk_length": 40,
-            "num_walks": 10,
             "window_size": 2,
             "iterations": 3,
             "random_seed": 3,

@@ -77,13 +79,13 @@ class LightRAG:
     )
 
     # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
-    embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
-    llm_model_name: str =
+    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
 

@@ -98,11 +100,11 @@ class LightRAG:
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json
 
-    def __post_init__(self):
+    def __post_init__(self):
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
-
+
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
 

@@ -133,30 +135,24 @@ class LightRAG:
             self.embedding_func
         )
 
-        self.entities_vdb = (
-
-
-
-
-            meta_fields={"entity_name"}
-        )
+        self.entities_vdb = self.vector_db_storage_cls(
+            namespace="entities",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+            meta_fields={"entity_name"},
         )
-        self.relationships_vdb = (
-
-
-
-
-            meta_fields={"src_id", "tgt_id"}
-        )
+        self.relationships_vdb = self.vector_db_storage_cls(
+            namespace="relationships",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+            meta_fields={"src_id", "tgt_id"},
        )
-        self.chunks_vdb = (
-
-
-
-            embedding_func=self.embedding_func,
-        )
+        self.chunks_vdb = self.vector_db_storage_cls(
+            namespace="chunks",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
        )
-
+
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
         )

@@ -177,7 +173,7 @@ class LightRAG:
         _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
         if not len(new_docs):
-            logger.warning(
+            logger.warning("All docs are already in the storage")
             return
         logger.info(f"[New Docs] inserting {len(new_docs)} docs")
 

@@ -203,7 +199,7 @@ class LightRAG:
                 k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
             }
             if not len(inserting_chunks):
-                logger.warning(
+                logger.warning("All chunks are already in the storage")
                 return
             logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
 

@@ -246,7 +242,7 @@ class LightRAG:
     def query(self, query: str, param: QueryParam = QueryParam()):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.aquery(query, param))
-
+
     async def aquery(self, query: str, param: QueryParam = QueryParam()):
         if param.mode == "local":
             response = await local_query(

@@ -290,7 +286,6 @@ class LightRAG:
             raise ValueError(f"Unknown mode {param.mode}")
         await self._query_done()
         return response
-
 
     async def _query_done(self):
         tasks = []

@@ -299,5 +294,3 @@ class LightRAG:
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
-
-
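After this import cleanup, a default `LightRAG()` still resolves to `gpt_4o_mini_complete` and `openai_embedding` from `lightrag.llm`, so the minimal usage from the README is unchanged. A short sketch under that assumption (the paths and query mirror the bundled example scripts):

# Assumed minimal usage of the defaults; "./dickens" and "./book.txt" are the
# placeholder paths used throughout the example scripts.
import os

from lightrag import LightRAG, QueryParam

WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(working_dir=WORKING_DIR)  # defaults: gpt_4o_mini_complete + openai_embedding

with open("./book.txt") as f:
    rag.insert(f.read())

print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))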
lightrag/llm.py CHANGED

@@ -1,9 +1,7 @@
 import os
 import copy
 import json
-import botocore
 import aioboto3
-import botocore.errorfactory
 import numpy as np
 import ollama
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout

@@ -13,24 +11,34 @@ from tenacity import (
     wait_exponential,
     retry_if_exception_type,
 )
-from transformers import
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
-
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def openai_complete_if_cache(
-    model,
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    base_url=None,
+    api_key=None,
+    **kwargs,
 ) -> str:
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key
 
-    openai_async_client =
+    openai_async_client = (
+        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+    )
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:

@@ -64,43 +72,56 @@ class BedrockError(Exception):
     retry=retry_if_exception_type((BedrockError)),
 )
 async def bedrock_complete_if_cache(
-    model,
-
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    aws_access_key_id=None,
+    aws_secret_access_key=None,
+    aws_session_token=None,
+    **kwargs,
 ) -> str:
-    os.environ[
-
-
+    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+        "AWS_ACCESS_KEY_ID", aws_access_key_id
+    )
+    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+    )
+    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+        "AWS_SESSION_TOKEN", aws_session_token
+    )
 
     # Fix message history format
     messages = []
     for history_message in history_messages:
         message = copy.copy(history_message)
-        message[
+        message["content"] = [{"text": message["content"]}]
         messages.append(message)
 
     # Add user prompt
-    messages.append({
+    messages.append({"role": "user", "content": [{"text": prompt}]})
 
     # Initialize Converse API arguments
-    args = {
-        'modelId': model,
-        'messages': messages
-    }
+    args = {"modelId": model, "messages": messages}
 
     # Define system prompt
     if system_prompt:
-        args[
+        args["system"] = [{"text": system_prompt}]
 
     # Map and set up inference parameters
     inference_params_map = {
-
-
-
+        "max_tokens": "maxTokens",
+        "top_p": "topP",
+        "stop_sequences": "stopSequences",
     }
-    if
-
+    if inference_params := list(
+        set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
+    ):
+        args["inferenceConfig"] = {}
     for param in inference_params:
-        args[
+        args["inferenceConfig"][inference_params_map.get(param, param)] = (
+            kwargs.pop(param)
+        )
 
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     if hashing_kv is not None:

@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
     # Call model via Converse API
     session = aioboto3.Session()
     async with session.client("bedrock-runtime") as bedrock_async_client:
-
         try:
             response = await bedrock_async_client.converse(**args, **kwargs)
         except Exception as e:
             raise BedrockError(e)
 
     if hashing_kv is not None:
-        await hashing_kv.upsert(
-
-
-
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": response["output"]["message"]["content"][0]["text"],
+                    "model": model,
+                }
             }
+        )
-
-    return response['output']['message']['content'][0]['text']
+
+    return response["output"]["message"]["content"][0]["text"]
 
 
 async def hf_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     model_name = model
-    hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map
-    if hf_tokenizer.pad_token
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+    if hf_tokenizer.pad_token is None:
         # print("use eos token")
         hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:

@@ -149,30 +172,51 @@ async def hf_model_if_cache(
         if_cache_return = await hashing_kv.get_by_id(args_hash)
         if if_cache_return is not None:
             return if_cache_return["return"]
-    input_prompt =
+    input_prompt = ""
     try:
-        input_prompt = hf_tokenizer.apply_chat_template(
+        input_prompt = hf_tokenizer.apply_chat_template(
-
     try:
             ori_message = copy.deepcopy(messages)
-            if messages[0][
-                messages[1][
                 messages = messages[1:]
-                input_prompt = hf_tokenizer.apply_chat_template(
-
             len_message = len(ori_message)
             for msgid in range(len_message):
-                input_prompt =
-
-
-
     response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
     if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {args_hash: {"return": response_text, "model": model}}
-        )
     return response_text
 
 async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:

@@ -202,6 +246,7 @@ async def ollama_model_if_cache(
 
     return result
 
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:

@@ -241,7 +286,7 @@ async def bedrock_complete(
 async def hf_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    model_name = kwargs[
     return await hf_model_if_cache(
         model_name,
         prompt,

@@ -250,10 +295,11 @@ async def hf_model_complete(
         **kwargs,
     )
 
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    model_name = kwargs[
     return await ollama_model_if_cache(
         model_name,
         prompt,

@@ -262,17 +308,25 @@ async def ollama_model_complete(
         **kwargs,
     )
 
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
-async def openai_embedding(
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key
 
-    openai_async_client =
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
     )

@@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
 # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
 # )
 async def bedrock_embedding(
-    texts: list[str],
 
     session = aioboto3.Session()
     async with session.client("bedrock-runtime") as bedrock_async_client:
-
         if (model_provider := model.split(".")[0]) == "amazon":
             embed_texts = []
             for text in texts:
                 if "v2" in model:
-                    body = json.dumps(
                 elif "v1" in model:
-                    body = json.dumps({
-                        'inputText': text
-                    })
                 else:
                     raise ValueError(f"Model {model} is not supported!")
 

@@ -315,29 +378,27 @@ async def bedrock_embedding(
                     modelId=model,
                     body=body,
                     accept="application/json",
-                    contentType="application/json"
                 )
 
-                response_body = await response.get(
 
-                embed_texts.append(response_body[
         elif model_provider == "cohere":
-            body = json.dumps(
-                'truncate': "NONE"
-            })
 
             response = await bedrock_async_client.invoke_model(
                 model=model,
                 body=body,
                 accept="application/json",
-                contentType="application/json"
             )
 
-            response_body = json.loads(response.get(
 
-            embed_texts = response_body[
         else:
             raise ValueError(f"Model provider '{model_provider}' is not supported!")
 

@@ -345,12 +406,15 @@ async def bedrock_embedding(
 
 
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
-    input_ids = tokenizer(
     with torch.no_grad():
         outputs = embed_model(input_ids)
         embeddings = outputs.last_hidden_state.mean(dim=1)
     return embeddings.detach().numpy()
 
 async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
     embed_text = []
     for text in texts:

@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
 
     return embed_text
 
 if __name__ == "__main__":
     import asyncio
 
     async def main():
-        result = await gpt_4o_mini_complete(
         print(result)
 
     asyncio.run(main())
|
| 178 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 179 |
+
)
|
| 180 |
+
except Exception:
|
| 181 |
try:
|
| 182 |
ori_message = copy.deepcopy(messages)
|
| 183 |
+
if messages[0]["role"] == "system":
|
| 184 |
+
messages[1]["content"] = (
|
| 185 |
+
"<system>"
|
| 186 |
+
+ messages[0]["content"]
|
| 187 |
+
+ "</system>\n"
|
| 188 |
+
+ messages[1]["content"]
|
| 189 |
+
)
|
| 190 |
messages = messages[1:]
|
| 191 |
+
input_prompt = hf_tokenizer.apply_chat_template(
|
| 192 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 193 |
+
)
|
| 194 |
+
except Exception:
|
| 195 |
len_message = len(ori_message)
|
| 196 |
for msgid in range(len_message):
|
| 197 |
+
input_prompt = (
|
| 198 |
+
input_prompt
|
| 199 |
+
+ "<"
|
| 200 |
+
+ ori_message[msgid]["role"]
|
| 201 |
+
+ ">"
|
| 202 |
+
+ ori_message[msgid]["content"]
|
| 203 |
+
+ "</"
|
| 204 |
+
+ ori_message[msgid]["role"]
|
| 205 |
+
+ ">\n"
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
input_ids = hf_tokenizer(
|
| 209 |
+
input_prompt, return_tensors="pt", padding=True, truncation=True
|
| 210 |
+
).to("cuda")
|
| 211 |
+
output = hf_model.generate(
|
| 212 |
+
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
|
| 213 |
+
)
|
| 214 |
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
| 215 |
if hashing_kv is not None:
|
| 216 |
+
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
|
|
|
|
|
|
|
| 217 |
return response_text
|
| 218 |
|
| 219 |
+
|
| 220 |
async def ollama_model_if_cache(
|
| 221 |
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
| 222 |
) -> str:
|
|
|
|
| 246 |
|
| 247 |
return result
|
| 248 |
|
| 249 |
+
|
| 250 |
async def gpt_4o_complete(
|
| 251 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 252 |
) -> str:
|
|
|
|
| 286 |
async def hf_model_complete(
|
| 287 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 288 |
) -> str:
|
| 289 |
+
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
| 290 |
return await hf_model_if_cache(
|
| 291 |
model_name,
|
| 292 |
prompt,
|
|
|
|
| 295 |
**kwargs,
|
| 296 |
)
|
| 297 |
|
| 298 |
+
|
| 299 |
async def ollama_model_complete(
|
| 300 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
| 301 |
) -> str:
|
| 302 |
+
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
| 303 |
return await ollama_model_if_cache(
|
| 304 |
model_name,
|
| 305 |
prompt,
|
|
|
|
| 308 |
**kwargs,
|
| 309 |
)
|
| 310 |
|
| 311 |
+
|
| 312 |
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
| 313 |
@retry(
|
| 314 |
stop=stop_after_attempt(3),
|
| 315 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 316 |
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
| 317 |
)
|
| 318 |
+
async def openai_embedding(
|
| 319 |
+
texts: list[str],
|
| 320 |
+
model: str = "text-embedding-3-small",
|
| 321 |
+
base_url: str = None,
|
| 322 |
+
api_key: str = None,
|
| 323 |
+
) -> np.ndarray:
|
| 324 |
if api_key:
|
| 325 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 326 |
|
| 327 |
+
openai_async_client = (
|
| 328 |
+
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
| 329 |
+
)
|
| 330 |
response = await openai_async_client.embeddings.create(
|
| 331 |
model=model, input=texts, encoding_format="float"
|
| 332 |
)
|
|
|
|
| 340 |
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
| 341 |
# )
|
| 342 |
async def bedrock_embedding(
|
| 343 |
+
texts: list[str],
|
| 344 |
+
model: str = "amazon.titan-embed-text-v2:0",
|
| 345 |
+
aws_access_key_id=None,
|
| 346 |
+
aws_secret_access_key=None,
|
| 347 |
+
aws_session_token=None,
|
| 348 |
+
) -> np.ndarray:
|
| 349 |
+
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
| 350 |
+
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
| 351 |
+
)
|
| 352 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
| 353 |
+
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
| 354 |
+
)
|
| 355 |
+
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
| 356 |
+
"AWS_SESSION_TOKEN", aws_session_token
|
| 357 |
+
)
|
| 358 |
|
| 359 |
session = aioboto3.Session()
|
| 360 |
async with session.client("bedrock-runtime") as bedrock_async_client:
|
|
|
|
| 361 |
if (model_provider := model.split(".")[0]) == "amazon":
|
| 362 |
embed_texts = []
|
| 363 |
for text in texts:
|
| 364 |
if "v2" in model:
|
| 365 |
+
body = json.dumps(
|
| 366 |
+
{
|
| 367 |
+
"inputText": text,
|
| 368 |
+
# 'dimensions': embedding_dim,
|
| 369 |
+
"embeddingTypes": ["float"],
|
| 370 |
+
}
|
| 371 |
+
)
|
| 372 |
elif "v1" in model:
|
| 373 |
+
body = json.dumps({"inputText": text})
|
|
|
|
|
|
|
| 374 |
else:
|
| 375 |
raise ValueError(f"Model {model} is not supported!")
|
| 376 |
|
|
|
|
| 378 |
modelId=model,
|
| 379 |
body=body,
|
| 380 |
accept="application/json",
|
| 381 |
+
contentType="application/json",
|
| 382 |
)
|
| 383 |
|
| 384 |
+
response_body = await response.get("body").json()
|
| 385 |
|
| 386 |
+
embed_texts.append(response_body["embedding"])
|
| 387 |
elif model_provider == "cohere":
|
| 388 |
+
body = json.dumps(
|
| 389 |
+
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
| 390 |
+
)
|
|
|
|
|
|
|
| 391 |
|
| 392 |
response = await bedrock_async_client.invoke_model(
|
| 393 |
model=model,
|
| 394 |
body=body,
|
| 395 |
accept="application/json",
|
| 396 |
+
contentType="application/json",
|
| 397 |
)
|
| 398 |
|
| 399 |
+
response_body = json.loads(response.get("body").read())
|
| 400 |
|
| 401 |
+
embed_texts = response_body["embeddings"]
|
| 402 |
else:
|
| 403 |
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
| 404 |
|
|
|
|
| 406 |
|
| 407 |
|
| 408 |
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
| 409 |
+
input_ids = tokenizer(
|
| 410 |
+
texts, return_tensors="pt", padding=True, truncation=True
|
| 411 |
+
).input_ids
|
| 412 |
with torch.no_grad():
|
| 413 |
outputs = embed_model(input_ids)
|
| 414 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 415 |
return embeddings.detach().numpy()
|
| 416 |
|
| 417 |
+
|
| 418 |
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
| 419 |
embed_text = []
|
| 420 |
for text in texts:
|
|
|
|
| 423 |
|
| 424 |
return embed_text
|
| 425 |
|
| 426 |
+
|
| 427 |
if __name__ == "__main__":
|
| 428 |
import asyncio
|
| 429 |
|
| 430 |
async def main():
|
| 431 |
+
result = await gpt_4o_mini_complete("How are you?")
|
| 432 |
print(result)
|
| 433 |
|
| 434 |
asyncio.run(main())
|
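Every remote completion and embedding call in the reformatted llm.py keeps the same retry policy (the helpers are imported from tenacity above these hunks): three attempts, exponential back-off between 4 and 10 seconds, and retries only on RateLimitError, APIConnectionError and Timeout. A minimal sketch of that pattern in isolation is shown below; the model name, prompt and client construction are illustrative assumptions, not part of the diff.

```python
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


@retry(
    stop=stop_after_attempt(3),                          # give up after three tries
    wait=wait_exponential(multiplier=1, min=4, max=10),  # back off 4s..10s between tries
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def complete(prompt: str, model: str = "gpt-4o-mini") -> str:
    # Transient API errors trigger a retry; anything else propagates immediately.
    client = AsyncOpenAI()
    response = await client.chat.completions.create(
        model=model, messages=[{"role": "user", "content": prompt}]
    )
    return response.choices[0].message.content
```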
lightrag/operate.py
CHANGED
|
@@ -25,6 +25,7 @@ from .base import (
|
|
| 25 |
)
|
| 26 |
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
| 27 |
|
|
|
|
| 28 |
def chunking_by_token_size(
|
| 29 |
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
| 30 |
):
|
|
@@ -45,6 +46,7 @@ def chunking_by_token_size(
|
|
| 45 |
)
|
| 46 |
return results
|
| 47 |
|
|
|
|
| 48 |
async def _handle_entity_relation_summary(
|
| 49 |
entity_or_relation_name: str,
|
| 50 |
description: str,
|
|
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
|
|
| 229 |
description=description,
|
| 230 |
keywords=keywords,
|
| 231 |
)
|
| 232 |
-
|
| 233 |
return edge_data
|
| 234 |
|
|
|
|
| 235 |
async def extract_entities(
|
| 236 |
chunks: dict[str, TextChunkSchema],
|
| 237 |
knwoledge_graph_inst: BaseGraphStorage,
|
|
@@ -352,7 +355,9 @@ async def extract_entities(
|
|
| 352 |
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
| 353 |
return None
|
| 354 |
if not len(all_relationships_data):
|
| 355 |
-
logger.warning(
|
|
|
|
|
|
|
| 356 |
return None
|
| 357 |
|
| 358 |
if entity_vdb is not None:
|
|
@@ -370,7 +375,10 @@ async def extract_entities(
|
|
| 370 |
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
| 371 |
"src_id": dp["src_id"],
|
| 372 |
"tgt_id": dp["tgt_id"],
|
| 373 |
-
"content": dp["keywords"]
|
|
|
|
|
|
|
|
|
|
| 374 |
}
|
| 375 |
for dp in all_relationships_data
|
| 376 |
}
|
|
@@ -378,6 +386,7 @@ async def extract_entities(
|
|
| 378 |
|
| 379 |
return knwoledge_graph_inst
|
| 380 |
|
|
|
|
| 381 |
async def local_query(
|
| 382 |
query,
|
| 383 |
knowledge_graph_inst: BaseGraphStorage,
|
|
@@ -393,19 +402,24 @@ async def local_query(
| 393 |     kw_prompt_temp = PROMPTS["keywords_extraction"]
| 394 |     kw_prompt = kw_prompt_temp.format(query=query)
| 395 |     result = await use_model_func(kw_prompt)
| 396 | -
| 397 |     try:
| 398 |         keywords_data = json.loads(result)
| 399 |         keywords = keywords_data.get("low_level_keywords", [])
| 400 | -       keywords =
| 401 | -   except json.JSONDecodeError
| 402 |         try:
| 403 | -           result =
| 404 | -
| 405 |
| 406 |             keywords_data = json.loads(result)
| 407 |             keywords = keywords_data.get("low_level_keywords", [])
| 408 | -           keywords =
| 409 |         # Handle parsing error
| 410 |         except json.JSONDecodeError as e:
| 411 |             print(f"JSON parsing error: {e}")
@@ -430,11 +444,20 @@ async def local_query(
| 430 |         query,
| 431 |         system_prompt=sys_prompt,
| 432 |     )
| 433 | -   if len(response)>len(sys_prompt):
| 434 | -       response =
| 435 | -
| 436 |     return response
| 437 |
| 438 |
async def _build_local_query_context(
|
| 439 |
query,
|
| 440 |
knowledge_graph_inst: BaseGraphStorage,
|
|
@@ -516,6 +539,7 @@ async def _build_local_query_context(
|
|
| 516 |
```
|
| 517 |
"""
|
| 518 |
|
|
|
|
| 519 |
async def _find_most_related_text_unit_from_entities(
|
| 520 |
node_datas: list[dict],
|
| 521 |
query_param: QueryParam,
|
|
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
|
|
| 576 |
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
| 577 |
return all_text_units
|
| 578 |
|
|
|
|
| 579 |
async def _find_most_related_edges_from_entities(
|
| 580 |
node_datas: list[dict],
|
| 581 |
query_param: QueryParam,
|
|
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
|
|
| 609 |
)
|
| 610 |
return all_edges_data
|
| 611 |
|
|
|
|
| 612 |
async def global_query(
|
| 613 |
query,
|
| 614 |
knowledge_graph_inst: BaseGraphStorage,
|
|
@@ -624,20 +650,25 @@ async def global_query(
| 624 |     kw_prompt_temp = PROMPTS["keywords_extraction"]
| 625 |     kw_prompt = kw_prompt_temp.format(query=query)
| 626 |     result = await use_model_func(kw_prompt)
| 627 | -
| 628 |     try:
| 629 |         keywords_data = json.loads(result)
| 630 |         keywords = keywords_data.get("high_level_keywords", [])
| 631 | -       keywords =
| 632 | -   except json.JSONDecodeError
| 633 |         try:
| 634 | -           result =
| 635 | -
| 636 |
| 637 |             keywords_data = json.loads(result)
| 638 |             keywords = keywords_data.get("high_level_keywords", [])
| 639 | -           keywords =
| 640 | -
| 641 |     except json.JSONDecodeError as e:
| 642 |         # Handle parsing error
| 643 |         print(f"JSON parsing error: {e}")
@@ -651,12 +682,12 @@ async def global_query(
|
|
| 651 |
text_chunks_db,
|
| 652 |
query_param,
|
| 653 |
)
|
| 654 |
-
|
| 655 |
if query_param.only_need_context:
|
| 656 |
return context
|
| 657 |
if context is None:
|
| 658 |
return PROMPTS["fail_response"]
|
| 659 |
-
|
| 660 |
sys_prompt_temp = PROMPTS["rag_response"]
|
| 661 |
sys_prompt = sys_prompt_temp.format(
|
| 662 |
context_data=context, response_type=query_param.response_type
|
|
@@ -665,11 +696,20 @@ async def global_query(
| 665 |         query,
| 666 |         system_prompt=sys_prompt,
| 667 |     )
| 668 | -   if len(response)>len(sys_prompt):
| 669 | -       response =
| 670 | -
| 671 |     return response
| 672 |
| 673 |
async def _build_global_query_context(
|
| 674 |
keywords,
|
| 675 |
knowledge_graph_inst: BaseGraphStorage,
|
|
@@ -679,14 +719,14 @@ async def _build_global_query_context(
|
|
| 679 |
query_param: QueryParam,
|
| 680 |
):
|
| 681 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
| 682 |
-
|
| 683 |
if not len(results):
|
| 684 |
return None
|
| 685 |
-
|
| 686 |
edge_datas = await asyncio.gather(
|
| 687 |
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
| 688 |
)
|
| 689 |
-
|
| 690 |
if not all([n is not None for n in edge_datas]):
|
| 691 |
logger.warning("Some edges are missing, maybe the storage is damaged")
|
| 692 |
edge_degree = await asyncio.gather(
|
|
@@ -765,6 +805,7 @@ async def _build_global_query_context(
|
|
| 765 |
```
|
| 766 |
"""
|
| 767 |
|
|
|
|
| 768 |
async def _find_most_related_entities_from_relationships(
|
| 769 |
edge_datas: list[dict],
|
| 770 |
query_param: QueryParam,
|
|
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
|
|
| 774 |
for e in edge_datas:
|
| 775 |
entity_names.add(e["src_id"])
|
| 776 |
entity_names.add(e["tgt_id"])
|
| 777 |
-
|
| 778 |
node_datas = await asyncio.gather(
|
| 779 |
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
| 780 |
)
|
|
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
|
|
| 795 |
|
| 796 |
return node_datas
|
| 797 |
|
|
|
|
| 798 |
async def _find_related_text_unit_from_relationships(
|
| 799 |
edge_datas: list[dict],
|
| 800 |
query_param: QueryParam,
|
| 801 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 802 |
knowledge_graph_inst: BaseGraphStorage,
|
| 803 |
):
|
| 804 |
-
|
| 805 |
text_units = [
|
| 806 |
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
| 807 |
for dp in edge_datas
|
|
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
|
|
| 816 |
"data": await text_chunks_db.get_by_id(c_id),
|
| 817 |
"order": index,
|
| 818 |
}
|
| 819 |
-
|
| 820 |
if any([v is None for v in all_text_units_lookup.values()]):
|
| 821 |
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
| 822 |
all_text_units = [
|
| 823 |
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
| 824 |
]
|
| 825 |
-
all_text_units = sorted(
|
| 826 |
-
all_text_units, key=lambda x: x["order"]
|
| 827 |
-
)
|
| 828 |
all_text_units = truncate_list_by_token_size(
|
| 829 |
all_text_units,
|
| 830 |
key=lambda x: x["data"]["content"],
|
|
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
|
|
| 834 |
|
| 835 |
return all_text_units
|
| 836 |
|
|
|
|
| 837 |
async def hybrid_query(
|
| 838 |
query,
|
| 839 |
knowledge_graph_inst: BaseGraphStorage,
|
|
@@ -849,24 +889,29 @@ async def hybrid_query(
| 849 |
| 850 |     kw_prompt_temp = PROMPTS["keywords_extraction"]
| 851 |     kw_prompt = kw_prompt_temp.format(query=query)
| 852 | -
| 853 |     result = await use_model_func(kw_prompt)
| 854 |     try:
| 855 |         keywords_data = json.loads(result)
| 856 |         hl_keywords = keywords_data.get("high_level_keywords", [])
| 857 |         ll_keywords = keywords_data.get("low_level_keywords", [])
| 858 | -       hl_keywords =
| 859 | -       ll_keywords =
| 860 | -   except json.JSONDecodeError
| 861 |         try:
| 862 | -           result =
| 863 | -
| 864 |
| 865 |             keywords_data = json.loads(result)
| 866 |             hl_keywords = keywords_data.get("high_level_keywords", [])
| 867 |             ll_keywords = keywords_data.get("low_level_keywords", [])
| 868 | -           hl_keywords =
| 869 | -           ll_keywords =
| 870 |         # Handle parsing error
| 871 |         except json.JSONDecodeError as e:
| 872 |             print(f"JSON parsing error: {e}")
@@ -897,7 +942,7 @@ async def hybrid_query(
|
|
| 897 |
return context
|
| 898 |
if context is None:
|
| 899 |
return PROMPTS["fail_response"]
|
| 900 |
-
|
| 901 |
sys_prompt_temp = PROMPTS["rag_response"]
|
| 902 |
sys_prompt = sys_prompt_temp.format(
|
| 903 |
context_data=context, response_type=query_param.response_type
|
|
@@ -906,53 +951,78 @@ async def hybrid_query(
| 906 |         query,
| 907 |         system_prompt=sys_prompt,
| 908 |     )
| 909 | -   if len(response)>len(sys_prompt):
| 910 | -       response =
| 911 |     return response
| 912 |
| 913 | def combine_contexts(high_level_context, low_level_context):
| 914 |     # Function to extract entities, relationships, and sources from context strings
| 915 |
| 916 |     def extract_sections(context):
| 917 | -       entities_match = re.search(
| 918 | -
| 919 | -
| 920 | -
| 921 | -
| 922 | -
| 923 | -
| 924 | -
| 925 |         return entities, relationships, sources
| 926 | -
| 927 |     # Extract sections from both contexts
| 928 |
| 929 | -   if high_level_context
| 930 | -       warnings.warn(
| 931 | -
| 932 |     else:
| 933 |         hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
| 934 |
| 935 | -
| 936 | -
| 937 | -
| 938 | -
| 939 |     else:
| 940 |         ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
| 941 |
| 942 | -
| 943 | -
| 944 |     # Combine and deduplicate the entities
| 945 | -   combined_entities_set = set(
| 946 | -
| 947 | -
| 948 |     # Combine and deduplicate the relationships
| 949 | -   combined_relationships_set = set(
| 950 | -
| 951 | -
| 952 |     # Combine and deduplicate the sources
| 953 | -   combined_sources_set = set(
| 954 | -
| 955 | -
| 956 |     # Format the combined context
| 957 |     return f"""
| 958 | -----Entities-----
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
|
|
| 964 |
{combined_sources}
|
| 965 |
"""
|
| 966 |
|
|
|
|
| 967 |
async def naive_query(
|
| 968 |
query,
|
| 969 |
chunks_vdb: BaseVectorStorage,
|
|
@@ -996,8 +1067,16 @@ async def naive_query(
| 996 |         system_prompt=sys_prompt,
| 997 |     )
| 998 |
| 999 | -   if len(response)>len(sys_prompt):
| 1000 | -      response =
| 1001 | -
| 1002 | -
| 1003 |
|
| 25 |
)
|
| 26 |
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
| 27 |
|
| 28 |
+
|
| 29 |
def chunking_by_token_size(
|
| 30 |
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
| 31 |
):
|
|
|
|
| 46 |
)
|
| 47 |
return results
|
| 48 |
|
| 49 |
+
|
| 50 |
async def _handle_entity_relation_summary(
|
| 51 |
entity_or_relation_name: str,
|
| 52 |
description: str,
|
|
|
|
| 231 |
description=description,
|
| 232 |
keywords=keywords,
|
| 233 |
)
|
| 234 |
+
|
| 235 |
return edge_data
|
| 236 |
|
| 237 |
+
|
| 238 |
async def extract_entities(
|
| 239 |
chunks: dict[str, TextChunkSchema],
|
| 240 |
knwoledge_graph_inst: BaseGraphStorage,
|
|
|
|
| 355 |
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
| 356 |
return None
|
| 357 |
if not len(all_relationships_data):
|
| 358 |
+
logger.warning(
|
| 359 |
+
"Didn't extract any relationships, maybe your LLM is not working"
|
| 360 |
+
)
|
| 361 |
return None
|
| 362 |
|
| 363 |
if entity_vdb is not None:
|
|
|
|
| 375 |
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
| 376 |
"src_id": dp["src_id"],
|
| 377 |
"tgt_id": dp["tgt_id"],
|
| 378 |
+
"content": dp["keywords"]
|
| 379 |
+
+ dp["src_id"]
|
| 380 |
+
+ dp["tgt_id"]
|
| 381 |
+
+ dp["description"],
|
| 382 |
}
|
| 383 |
for dp in all_relationships_data
|
| 384 |
}
|
|
|
|
| 386 |
|
| 387 |
return knwoledge_graph_inst
|
| 388 |
|
| 389 |
+
|
| 390 |
async def local_query(
|
| 391 |
query,
|
| 392 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
|
| 402 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
| 403 |
kw_prompt = kw_prompt_temp.format(query=query)
|
| 404 |
result = await use_model_func(kw_prompt)
|
| 405 |
+
|
| 406 |
try:
|
| 407 |
keywords_data = json.loads(result)
|
| 408 |
keywords = keywords_data.get("low_level_keywords", [])
|
| 409 |
+
keywords = ", ".join(keywords)
|
| 410 |
+
except json.JSONDecodeError:
|
| 411 |
try:
|
| 412 |
+
result = (
|
| 413 |
+
result.replace(kw_prompt[:-1], "")
|
| 414 |
+
.replace("user", "")
|
| 415 |
+
.replace("model", "")
|
| 416 |
+
.strip()
|
| 417 |
+
)
|
| 418 |
+
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
| 419 |
|
| 420 |
keywords_data = json.loads(result)
|
| 421 |
keywords = keywords_data.get("low_level_keywords", [])
|
| 422 |
+
keywords = ", ".join(keywords)
|
| 423 |
# Handle parsing error
|
| 424 |
except json.JSONDecodeError as e:
|
| 425 |
print(f"JSON parsing error: {e}")
|
|
|
|
| 444 |
query,
|
| 445 |
system_prompt=sys_prompt,
|
| 446 |
)
|
| 447 |
+
if len(response) > len(sys_prompt):
|
| 448 |
+
response = (
|
| 449 |
+
response.replace(sys_prompt, "")
|
| 450 |
+
.replace("user", "")
|
| 451 |
+
.replace("model", "")
|
| 452 |
+
.replace(query, "")
|
| 453 |
+
.replace("<system>", "")
|
| 454 |
+
.replace("</system>", "")
|
| 455 |
+
.strip()
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
return response
|
| 459 |
|
| 460 |
+
|
| 461 |
async def _build_local_query_context(
|
| 462 |
query,
|
| 463 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
|
| 539 |
```
|
| 540 |
"""
|
| 541 |
|
| 542 |
+
|
| 543 |
async def _find_most_related_text_unit_from_entities(
|
| 544 |
node_datas: list[dict],
|
| 545 |
query_param: QueryParam,
|
|
|
|
| 600 |
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
| 601 |
return all_text_units
|
| 602 |
|
| 603 |
+
|
| 604 |
async def _find_most_related_edges_from_entities(
|
| 605 |
node_datas: list[dict],
|
| 606 |
query_param: QueryParam,
|
|
|
|
| 634 |
)
|
| 635 |
return all_edges_data
|
| 636 |
|
| 637 |
+
|
| 638 |
async def global_query(
|
| 639 |
query,
|
| 640 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
|
| 650 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
| 651 |
kw_prompt = kw_prompt_temp.format(query=query)
|
| 652 |
result = await use_model_func(kw_prompt)
|
| 653 |
+
|
| 654 |
try:
|
| 655 |
keywords_data = json.loads(result)
|
| 656 |
keywords = keywords_data.get("high_level_keywords", [])
|
| 657 |
+
keywords = ", ".join(keywords)
|
| 658 |
+
except json.JSONDecodeError:
|
| 659 |
try:
|
| 660 |
+
result = (
|
| 661 |
+
result.replace(kw_prompt[:-1], "")
|
| 662 |
+
.replace("user", "")
|
| 663 |
+
.replace("model", "")
|
| 664 |
+
.strip()
|
| 665 |
+
)
|
| 666 |
+
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
| 667 |
|
| 668 |
keywords_data = json.loads(result)
|
| 669 |
keywords = keywords_data.get("high_level_keywords", [])
|
| 670 |
+
keywords = ", ".join(keywords)
|
| 671 |
+
|
| 672 |
except json.JSONDecodeError as e:
|
| 673 |
# Handle parsing error
|
| 674 |
print(f"JSON parsing error: {e}")
|
|
|
|
| 682 |
text_chunks_db,
|
| 683 |
query_param,
|
| 684 |
)
|
| 685 |
+
|
| 686 |
if query_param.only_need_context:
|
| 687 |
return context
|
| 688 |
if context is None:
|
| 689 |
return PROMPTS["fail_response"]
|
| 690 |
+
|
| 691 |
sys_prompt_temp = PROMPTS["rag_response"]
|
| 692 |
sys_prompt = sys_prompt_temp.format(
|
| 693 |
context_data=context, response_type=query_param.response_type
|
|
|
|
| 696 |
query,
|
| 697 |
system_prompt=sys_prompt,
|
| 698 |
)
|
| 699 |
+
if len(response) > len(sys_prompt):
|
| 700 |
+
response = (
|
| 701 |
+
response.replace(sys_prompt, "")
|
| 702 |
+
.replace("user", "")
|
| 703 |
+
.replace("model", "")
|
| 704 |
+
.replace(query, "")
|
| 705 |
+
.replace("<system>", "")
|
| 706 |
+
.replace("</system>", "")
|
| 707 |
+
.strip()
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
return response
|
| 711 |
|
| 712 |
+
|
| 713 |
async def _build_global_query_context(
|
| 714 |
keywords,
|
| 715 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
|
| 719 |
query_param: QueryParam,
|
| 720 |
):
|
| 721 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
| 722 |
+
|
| 723 |
if not len(results):
|
| 724 |
return None
|
| 725 |
+
|
| 726 |
edge_datas = await asyncio.gather(
|
| 727 |
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
| 728 |
)
|
| 729 |
+
|
| 730 |
if not all([n is not None for n in edge_datas]):
|
| 731 |
logger.warning("Some edges are missing, maybe the storage is damaged")
|
| 732 |
edge_degree = await asyncio.gather(
|
|
|
|
| 805 |
```
|
| 806 |
"""
|
| 807 |
|
| 808 |
+
|
| 809 |
async def _find_most_related_entities_from_relationships(
|
| 810 |
edge_datas: list[dict],
|
| 811 |
query_param: QueryParam,
|
|
|
|
| 815 |
for e in edge_datas:
|
| 816 |
entity_names.add(e["src_id"])
|
| 817 |
entity_names.add(e["tgt_id"])
|
| 818 |
+
|
| 819 |
node_datas = await asyncio.gather(
|
| 820 |
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
| 821 |
)
|
|
|
|
| 836 |
|
| 837 |
return node_datas
|
| 838 |
|
| 839 |
+
|
| 840 |
async def _find_related_text_unit_from_relationships(
|
| 841 |
edge_datas: list[dict],
|
| 842 |
query_param: QueryParam,
|
| 843 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 844 |
knowledge_graph_inst: BaseGraphStorage,
|
| 845 |
):
|
|
|
|
| 846 |
text_units = [
|
| 847 |
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
| 848 |
for dp in edge_datas
|
|
|
|
| 857 |
"data": await text_chunks_db.get_by_id(c_id),
|
| 858 |
"order": index,
|
| 859 |
}
|
| 860 |
+
|
| 861 |
if any([v is None for v in all_text_units_lookup.values()]):
|
| 862 |
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
| 863 |
all_text_units = [
|
| 864 |
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
| 865 |
]
|
| 866 |
+
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
|
|
|
|
|
|
|
| 867 |
all_text_units = truncate_list_by_token_size(
|
| 868 |
all_text_units,
|
| 869 |
key=lambda x: x["data"]["content"],
|
|
|
|
| 873 |
|
| 874 |
return all_text_units
|
| 875 |
|
| 876 |
+
|
| 877 |
async def hybrid_query(
|
| 878 |
query,
|
| 879 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
|
| 889 |
|
| 890 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
| 891 |
kw_prompt = kw_prompt_temp.format(query=query)
|
| 892 |
+
|
| 893 |
result = await use_model_func(kw_prompt)
|
| 894 |
try:
|
| 895 |
keywords_data = json.loads(result)
|
| 896 |
hl_keywords = keywords_data.get("high_level_keywords", [])
|
| 897 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
| 898 |
+
hl_keywords = ", ".join(hl_keywords)
|
| 899 |
+
ll_keywords = ", ".join(ll_keywords)
|
| 900 |
+
except json.JSONDecodeError:
|
| 901 |
try:
|
| 902 |
+
result = (
|
| 903 |
+
result.replace(kw_prompt[:-1], "")
|
| 904 |
+
.replace("user", "")
|
| 905 |
+
.replace("model", "")
|
| 906 |
+
.strip()
|
| 907 |
+
)
|
| 908 |
+
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
| 909 |
|
| 910 |
keywords_data = json.loads(result)
|
| 911 |
hl_keywords = keywords_data.get("high_level_keywords", [])
|
| 912 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
| 913 |
+
hl_keywords = ", ".join(hl_keywords)
|
| 914 |
+
ll_keywords = ", ".join(ll_keywords)
|
| 915 |
# Handle parsing error
|
| 916 |
except json.JSONDecodeError as e:
|
| 917 |
print(f"JSON parsing error: {e}")
|
|
|
|
| 942 |
return context
|
| 943 |
if context is None:
|
| 944 |
return PROMPTS["fail_response"]
|
| 945 |
+
|
| 946 |
sys_prompt_temp = PROMPTS["rag_response"]
|
| 947 |
sys_prompt = sys_prompt_temp.format(
|
| 948 |
context_data=context, response_type=query_param.response_type
|
|
|
|
| 951 |
query,
|
| 952 |
system_prompt=sys_prompt,
|
| 953 |
)
|
| 954 |
+
if len(response) > len(sys_prompt):
|
| 955 |
+
response = (
|
| 956 |
+
response.replace(sys_prompt, "")
|
| 957 |
+
.replace("user", "")
|
| 958 |
+
.replace("model", "")
|
| 959 |
+
.replace(query, "")
|
| 960 |
+
.replace("<system>", "")
|
| 961 |
+
.replace("</system>", "")
|
| 962 |
+
.strip()
|
| 963 |
+
)
|
| 964 |
return response
|
| 965 |
|
| 966 |
+
|
| 967 |
def combine_contexts(high_level_context, low_level_context):
|
| 968 |
# Function to extract entities, relationships, and sources from context strings
|
| 969 |
|
| 970 |
def extract_sections(context):
|
| 971 |
+
entities_match = re.search(
|
| 972 |
+
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
| 973 |
+
)
|
| 974 |
+
relationships_match = re.search(
|
| 975 |
+
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
| 976 |
+
)
|
| 977 |
+
sources_match = re.search(
|
| 978 |
+
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
entities = entities_match.group(1) if entities_match else ""
|
| 982 |
+
relationships = relationships_match.group(1) if relationships_match else ""
|
| 983 |
+
sources = sources_match.group(1) if sources_match else ""
|
| 984 |
+
|
| 985 |
return entities, relationships, sources
|
| 986 |
+
|
| 987 |
# Extract sections from both contexts
|
| 988 |
|
| 989 |
+
if high_level_context is None:
|
| 990 |
+
warnings.warn(
|
| 991 |
+
"High Level context is None. Return empty High entity/relationship/source"
|
| 992 |
+
)
|
| 993 |
+
hl_entities, hl_relationships, hl_sources = "", "", ""
|
| 994 |
else:
|
| 995 |
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
| 996 |
|
| 997 |
+
if low_level_context is None:
|
| 998 |
+
warnings.warn(
|
| 999 |
+
"Low Level context is None. Return empty Low entity/relationship/source"
|
| 1000 |
+
)
|
| 1001 |
+
ll_entities, ll_relationships, ll_sources = "", "", ""
|
| 1002 |
else:
|
| 1003 |
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
| 1004 |
|
|
|
|
|
|
|
| 1005 |
# Combine and deduplicate the entities
|
| 1006 |
+
combined_entities_set = set(
|
| 1007 |
+
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
|
| 1008 |
+
)
|
| 1009 |
+
combined_entities = "\n".join(combined_entities_set)
|
| 1010 |
+
|
| 1011 |
# Combine and deduplicate the relationships
|
| 1012 |
+
combined_relationships_set = set(
|
| 1013 |
+
filter(
|
| 1014 |
+
None,
|
| 1015 |
+
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
|
| 1016 |
+
)
|
| 1017 |
+
)
|
| 1018 |
+
combined_relationships = "\n".join(combined_relationships_set)
|
| 1019 |
+
|
| 1020 |
# Combine and deduplicate the sources
|
| 1021 |
+
combined_sources_set = set(
|
| 1022 |
+
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
|
| 1023 |
+
)
|
| 1024 |
+
combined_sources = "\n".join(combined_sources_set)
|
| 1025 |
+
|
| 1026 |
# Format the combined context
|
| 1027 |
return f"""
|
| 1028 |
-----Entities-----
|
|
|
|
| 1034 |
{combined_sources}
|
| 1035 |
"""
|
| 1036 |
|
| 1037 |
+
|
| 1038 |
async def naive_query(
|
| 1039 |
query,
|
| 1040 |
chunks_vdb: BaseVectorStorage,
|
|
|
|
| 1067 |
system_prompt=sys_prompt,
|
| 1068 |
)
|
| 1069 |
|
| 1070 |
+
if len(response) > len(sys_prompt):
|
| 1071 |
+
response = (
|
| 1072 |
+
response[len(sys_prompt) :]
|
| 1073 |
+
.replace(sys_prompt, "")
|
| 1074 |
+
.replace("user", "")
|
| 1075 |
+
.replace("model", "")
|
| 1076 |
+
.replace(query, "")
|
| 1077 |
+
.replace("<system>", "")
|
| 1078 |
+
.replace("</system>", "")
|
| 1079 |
+
.strip()
|
| 1080 |
+
)
|
| 1081 |
|
| 1082 |
+
return response
|
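local_query, global_query and hybrid_query now share the same two-stage recovery when the keyword-extraction reply is not valid JSON: strip the echoed prompt and chat-role tokens, keep only the text between the first `{` and the first `}`, then parse again. A standalone sketch of that fallback is below; the function name and sample arguments are illustrative, not part of the diff.

```python
import json


def parse_keywords(raw: str, kw_prompt: str, key: str = "high_level_keywords") -> str:
    """Best-effort extraction of a keyword list from an LLM reply, joined with ', '."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        # Some models echo the prompt and role markers before the JSON payload,
        # so strip those and keep only the first {...} block before reparsing.
        cleaned = (
            raw.replace(kw_prompt[:-1], "")
            .replace("user", "")
            .replace("model", "")
            .strip()
        )
        cleaned = "{" + cleaned.split("{")[1].split("}")[0] + "}"
        data = json.loads(cleaned)
    return ", ".join(data.get(key, []))
```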
lightrag/prompt.py
CHANGED
|
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
|
|
| 9 |
|
| 10 |
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
|
| 11 |
|
| 12 |
-
PROMPTS[
|
| 13 |
-
"entity_extraction"
|
| 14 |
-
] = """-Goal-
|
| 15 |
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
| 16 |
|
| 17 |
-Steps-
|
|
@@ -32,7 +30,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tupl
|
|
| 32 |
|
| 33 |
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
| 34 |
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
| 35 |
-
|
| 36 |
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
| 37 |
|
| 38 |
5. When finished, output {completion_delimiter}
|
|
@@ -146,9 +144,7 @@ PROMPTS[
|
|
| 146 |
|
| 147 |
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
| 148 |
|
| 149 |
-
PROMPTS[
|
| 150 |
-
"rag_response"
|
| 151 |
-
] = """---Role---
|
| 152 |
|
| 153 |
You are a helpful assistant responding to questions about data in the tables provided.
|
| 154 |
|
|
@@ -241,9 +237,7 @@ Output:
|
|
| 241 |
|
| 242 |
"""
|
| 243 |
|
| 244 |
-
PROMPTS[
|
| 245 |
-
"naive_rag_response"
|
| 246 |
-
] = """You're a helpful assistant
|
| 247 |
Below are the knowledge you know:
|
| 248 |
{content_data}
|
| 249 |
---
|
|
|
|
| 9 |
|
| 10 |
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
|
| 11 |
|
| 12 |
+
PROMPTS["entity_extraction"] = """-Goal-
|
|
|
|
|
|
|
| 13 |
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
| 14 |
|
| 15 |
-Steps-
|
|
|
|
| 30 |
|
| 31 |
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
| 32 |
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
| 33 |
+
|
| 34 |
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
| 35 |
|
| 36 |
5. When finished, output {completion_delimiter}
|
|
|
|
| 144 |
|
| 145 |
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
| 146 |
|
| 147 |
+
PROMPTS["rag_response"] = """---Role---
|
|
|
|
|
|
|
| 148 |
|
| 149 |
You are a helpful assistant responding to questions about data in the tables provided.
|
| 150 |
|
|
|
|
| 237 |
|
| 238 |
"""
|
| 239 |
|
| 240 |
+
PROMPTS["naive_rag_response"] = """You're a helpful assistant
|
|
|
|
|
|
|
| 241 |
Below are the knowledge you know:
|
| 242 |
{content_data}
|
| 243 |
---
|
lightrag/storage.py
CHANGED
|
@@ -1,16 +1,11 @@
| 1 |  import asyncio
| 2 |  import html
| 3 | -import json
| 4 |  import os
| 5 | -from
| 6 | -from dataclasses import dataclass, field
| 7 |  from typing import Any, Union, cast
| 8 | -import pickle
| 9 | -import hnswlib
| 10 |  import networkx as nx
| 11 |  import numpy as np
| 12 |  from nano_vectordb import NanoVectorDB
| 13 | -import xxhash
| 14 |
| 15 |  from .utils import load_json, logger, write_json
| 16 |  from .base import (
@@ -19,6 +14,7 @@ from .base import (
|
|
| 19 |
BaseVectorStorage,
|
| 20 |
)
|
| 21 |
|
|
|
|
| 22 |
@dataclass
|
| 23 |
class JsonKVStorage(BaseKVStorage):
|
| 24 |
def __post_init__(self):
|
|
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
|
|
| 59 |
async def drop(self):
|
| 60 |
self._data = {}
|
| 61 |
|
|
|
|
| 62 |
@dataclass
|
| 63 |
class NanoVectorDBStorage(BaseVectorStorage):
|
| 64 |
cosine_better_than_threshold: float = 0.2
|
| 65 |
|
| 66 |
def __post_init__(self):
|
| 67 |
-
|
| 68 |
self._client_file_name = os.path.join(
|
| 69 |
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
| 70 |
)
|
|
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
| 118 |
async def index_done_callback(self):
|
| 119 |
self._client.save()
|
| 120 |
|
|
|
|
| 121 |
@dataclass
|
| 122 |
class NetworkXStorage(BaseGraphStorage):
|
| 123 |
@staticmethod
|
|
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
| 142 |
| 143 |         graph = graph.copy()
| 144 |         graph = cast(nx.Graph, largest_connected_component(graph))
| 145 | -       node_mapping = {
| 146 |         graph = nx.relabel_nodes(graph, node_mapping)
| 147 |         return NetworkXStorage._stabilize_graph(graph)
| 148 |
|
| 1 |
import asyncio
|
| 2 |
import html
|
|
|
|
| 3 |
import os
|
| 4 |
+
from dataclasses import dataclass
|
|
|
|
| 5 |
from typing import Any, Union, cast
|
|
|
|
|
|
|
| 6 |
import networkx as nx
|
| 7 |
import numpy as np
|
| 8 |
from nano_vectordb import NanoVectorDB
|
|
|
|
| 9 |
|
| 10 |
from .utils import load_json, logger, write_json
|
| 11 |
from .base import (
|
|
|
|
| 14 |
BaseVectorStorage,
|
| 15 |
)
|
| 16 |
|
| 17 |
+
|
| 18 |
@dataclass
|
| 19 |
class JsonKVStorage(BaseKVStorage):
|
| 20 |
def __post_init__(self):
|
|
|
|
| 55 |
async def drop(self):
|
| 56 |
self._data = {}
|
| 57 |
|
| 58 |
+
|
| 59 |
@dataclass
|
| 60 |
class NanoVectorDBStorage(BaseVectorStorage):
|
| 61 |
cosine_better_than_threshold: float = 0.2
|
| 62 |
|
| 63 |
def __post_init__(self):
|
|
|
|
| 64 |
self._client_file_name = os.path.join(
|
| 65 |
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
| 66 |
)
|
|
|
|
| 114 |
async def index_done_callback(self):
|
| 115 |
self._client.save()
|
| 116 |
|
| 117 |
+
|
| 118 |
@dataclass
|
| 119 |
class NetworkXStorage(BaseGraphStorage):
|
| 120 |
@staticmethod
|
|
|
|
| 139 |
|
| 140 |
graph = graph.copy()
|
| 141 |
graph = cast(nx.Graph, largest_connected_component(graph))
|
| 142 |
+
node_mapping = {
|
| 143 |
+
node: html.unescape(node.upper().strip()) for node in graph.nodes()
|
| 144 |
+
} # type: ignore
|
| 145 |
graph = nx.relabel_nodes(graph, node_mapping)
|
| 146 |
return NetworkXStorage._stabilize_graph(graph)
|
| 147 |
|
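The only substantive-looking change in storage.py is still cosmetic: the node-relabelling step in NetworkXStorage is now a single dict comprehension that HTML-unescapes, upper-cases and strips every node name in the largest connected component before the graph is stabilized. A small self-contained sketch of the same normalization, on a throwaway graph, might look like this.

```python
import html

import networkx as nx

graph = nx.Graph()
graph.add_edge("  alice &amp; co  ", "bob")

# Normalize node labels the same way NetworkXStorage does before relabelling.
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}
graph = nx.relabel_nodes(graph, node_mapping)

print(sorted(graph.nodes()))  # ['ALICE & CO', 'BOB']
```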
lightrag/utils.py
CHANGED
|
@@ -16,18 +16,22 @@ ENCODER = None
|
|
| 16 |
|
| 17 |
logger = logging.getLogger("lightrag")
|
| 18 |
|
|
|
|
| 19 |
def set_logger(log_file: str):
|
| 20 |
logger.setLevel(logging.DEBUG)
|
| 21 |
|
| 22 |
file_handler = logging.FileHandler(log_file)
|
| 23 |
file_handler.setLevel(logging.DEBUG)
|
| 24 |
|
| 25 |
-
formatter = logging.Formatter(
|
|
|
|
|
|
|
| 26 |
file_handler.setFormatter(formatter)
|
| 27 |
|
| 28 |
if not logger.handlers:
|
| 29 |
logger.addHandler(file_handler)
|
| 30 |
|
|
|
|
| 31 |
@dataclass
|
| 32 |
class EmbeddingFunc:
|
| 33 |
embedding_dim: int
|
|
@@ -36,7 +40,8 @@ class EmbeddingFunc:
|
|
| 36 |
|
| 37 |
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
| 38 |
return await self.func(*args, **kwargs)
|
| 39 |
-
|
|
|
|
| 40 |
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
| 41 |
"""Locate the JSON string body from a string"""
|
| 42 |
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
|
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
|
| 45 |
else:
|
| 46 |
return None
|
| 47 |
|
|
|
|
| 48 |
def convert_response_to_json(response: str) -> dict:
|
| 49 |
json_str = locate_json_string_body_from_string(response)
|
| 50 |
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
|
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
|
|
| 55 |
logger.error(f"Failed to parse JSON: {json_str}")
|
| 56 |
raise e from None
|
| 57 |
|
|
|
|
| 58 |
def compute_args_hash(*args):
|
| 59 |
return md5(str(args).encode()).hexdigest()
|
| 60 |
|
|
|
|
| 61 |
def compute_mdhash_id(content, prefix: str = ""):
|
| 62 |
return prefix + md5(content.encode()).hexdigest()
|
| 63 |
|
|
|
|
| 64 |
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
| 65 |
"""Add restriction of maximum async calling times for a async func"""
|
| 66 |
|
|
@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
|
| 82 |
|
| 83 |
return final_decro
|
| 84 |
|
|
|
|
| 85 |
def wrap_embedding_func_with_attrs(**kwargs):
|
| 86 |
"""Wrap a function with attributes"""
|
| 87 |
|
|
@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
|
|
| 91 |
|
| 92 |
return final_decro
|
| 93 |
|
|
|
|
| 94 |
def load_json(file_name):
|
| 95 |
if not os.path.exists(file_name):
|
| 96 |
return None
|
| 97 |
with open(file_name, encoding="utf-8") as f:
|
| 98 |
return json.load(f)
|
| 99 |
|
|
|
|
| 100 |
def write_json(json_obj, file_name):
|
| 101 |
with open(file_name, "w", encoding="utf-8") as f:
|
| 102 |
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
| 103 |
|
|
|
|
| 104 |
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
| 105 |
global ENCODER
|
| 106 |
if ENCODER is None:
|
|
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
|
|
| 116 |
content = ENCODER.decode(tokens)
|
| 117 |
return content
|
| 118 |
|
|
|
|
| 119 |
def pack_user_ass_to_openai_messages(*args: str):
|
| 120 |
roles = ["user", "assistant"]
|
| 121 |
return [
|
| 122 |
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
| 123 |
]
|
| 124 |
|
|
|
|
| 125 |
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
| 126 |
"""Split a string by multiple markers"""
|
| 127 |
if not markers:
|
|
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
|
|
| 129 |
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
| 130 |
return [r.strip() for r in results if r.strip()]
|
| 131 |
|
|
|
|
| 132 |
# Refer the utils functions of the official GraphRAG implementation:
|
| 133 |
# https://github.com/microsoft/graphrag
|
| 134 |
def clean_str(input: Any) -> str:
|
|
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
|
|
| 141 |
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
| 142 |
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
| 143 |
|
|
|
|
| 144 |
def is_float_regex(value):
|
| 145 |
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
| 146 |
|
|
|
|
| 147 |
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
| 148 |
"""Truncate a list of data by token size"""
|
| 149 |
if max_token_size <= 0:
|
|
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
| 155 |             return list_data[:i]
| 156 |     return list_data
| 157 |
| 158 | def list_of_list_to_csv(data: list[list]):
| 159 |     return "\n".join(
| 160 |         [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
| 161 |     )
| 162 |
| 163 | def save_data_to_file(data, file_name):
| 164 | -   with open(file_name,
| 165 | -       json.dump(data, f, ensure_ascii=False, indent=4)
|
|
| 16 |
|
| 17 |
logger = logging.getLogger("lightrag")
|
| 18 |
|
| 19 |
+
|
| 20 |
def set_logger(log_file: str):
|
| 21 |
logger.setLevel(logging.DEBUG)
|
| 22 |
|
| 23 |
file_handler = logging.FileHandler(log_file)
|
| 24 |
file_handler.setLevel(logging.DEBUG)
|
| 25 |
|
| 26 |
+
formatter = logging.Formatter(
|
| 27 |
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 28 |
+
)
|
| 29 |
file_handler.setFormatter(formatter)
|
| 30 |
|
| 31 |
if not logger.handlers:
|
| 32 |
logger.addHandler(file_handler)
|
| 33 |
|
| 34 |
+
|
| 35 |
@dataclass
|
| 36 |
class EmbeddingFunc:
|
| 37 |
embedding_dim: int
|
|
|
|
| 40 |
|
| 41 |
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
| 42 |
return await self.func(*args, **kwargs)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
| 46 |
"""Locate the JSON string body from a string"""
|
| 47 |
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
|
|
|
| 50 |
else:
|
| 51 |
return None
|
| 52 |
|
| 53 |
+
|
| 54 |
def convert_response_to_json(response: str) -> dict:
|
| 55 |
json_str = locate_json_string_body_from_string(response)
|
| 56 |
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
|
|
|
| 61 |
logger.error(f"Failed to parse JSON: {json_str}")
|
| 62 |
raise e from None
|
| 63 |
|
| 64 |
+
|
| 65 |
def compute_args_hash(*args):
|
| 66 |
return md5(str(args).encode()).hexdigest()
|
| 67 |
|
| 68 |
+
|
| 69 |
def compute_mdhash_id(content, prefix: str = ""):
|
| 70 |
return prefix + md5(content.encode()).hexdigest()
|
| 71 |
|
| 72 |
+
|
| 73 |
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
| 74 |
"""Add restriction of maximum async calling times for a async func"""
|
| 75 |
|
|
|
|
| 91 |
|
| 92 |
return final_decro
|
| 93 |
|
| 94 |
+
|
| 95 |
def wrap_embedding_func_with_attrs(**kwargs):
|
| 96 |
"""Wrap a function with attributes"""
|
| 97 |
|
|
|
|
| 101 |
|
| 102 |
return final_decro
|
| 103 |
|
| 104 |
+
|
| 105 |
def load_json(file_name):
|
| 106 |
if not os.path.exists(file_name):
|
| 107 |
return None
|
| 108 |
with open(file_name, encoding="utf-8") as f:
|
| 109 |
return json.load(f)
|
| 110 |
|
| 111 |
+
|
| 112 |
def write_json(json_obj, file_name):
|
| 113 |
with open(file_name, "w", encoding="utf-8") as f:
|
| 114 |
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
| 115 |
|
| 116 |
+
|
| 117 |
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
| 118 |
global ENCODER
|
| 119 |
if ENCODER is None:
|
|
|
|
| 129 |
content = ENCODER.decode(tokens)
|
| 130 |
return content
|
| 131 |
|
| 132 |
+
|
| 133 |
def pack_user_ass_to_openai_messages(*args: str):
|
| 134 |
roles = ["user", "assistant"]
|
| 135 |
return [
|
| 136 |
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
| 137 |
]
|
| 138 |
|
| 139 |
+
|
| 140 |
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
| 141 |
"""Split a string by multiple markers"""
|
| 142 |
if not markers:
|
|
|
|
| 144 |
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
| 145 |
return [r.strip() for r in results if r.strip()]
|
| 146 |
|
| 147 |
+
|
| 148 |
# Refer the utils functions of the official GraphRAG implementation:
|
| 149 |
# https://github.com/microsoft/graphrag
|
| 150 |
def clean_str(input: Any) -> str:
|
|
|
|
| 157 |
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
| 158 |
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
| 159 |
|
| 160 |
+
|
| 161 |
def is_float_regex(value):
|
| 162 |
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
| 163 |
|
| 164 |
+
|
| 165 |
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
| 166 |
"""Truncate a list of data by token size"""
|
| 167 |
if max_token_size <= 0:
|
|
|
|
| 173 |
return list_data[:i]
|
| 174 |
return list_data
|
| 175 |
|
| 176 |
+
|
| 177 |
def list_of_list_to_csv(data: list[list]):
|
| 178 |
return "\n".join(
|
| 179 |
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
|
| 180 |
)
|
| 181 |
|
| 182 |
+
|
| 183 |
def save_data_to_file(data, file_name):
|
| 184 |
+
with open(file_name, "w", encoding="utf-8") as f:
|
| 185 |
+
json.dump(data, f, ensure_ascii=False, indent=4)
|
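Two of the helpers touched here are short enough to read in full from the diff: compute_mdhash_id builds the content-addressed ids used for entity and relationship vector records (the `ent-`/`rel-` prefixes seen in operate.py), and list_of_list_to_csv renders the context tables. A quick usage sketch, with made-up inputs:

```python
from hashlib import md5


def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # Stable content-addressed id, e.g. "rel-" + md5(src_id + tgt_id).
    return prefix + md5(content.encode()).hexdigest()


def list_of_list_to_csv(data: list[list]) -> str:
    # Comma-plus-tab separated rows, as embedded in the query context blocks.
    return "\n".join(
        [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
    )


print(compute_mdhash_id("HELLO" + "WORLD", prefix="rel-"))
print(list_of_list_to_csv([["id", "entity"], [0, "HELLO"]]))
```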
reproduce/Step_0.py
CHANGED
|
@@ -3,11 +3,11 @@ import json
| 3 |  import glob
| 4 |  import argparse
| 5 |
| 6 | -def extract_unique_contexts(input_directory, output_directory):
| 7 |
| 8 |      os.makedirs(output_directory, exist_ok=True)
| 9 |
| 10 | -    jsonl_files = glob.glob(os.path.join(input_directory,
| 11 |      print(f"Found {len(jsonl_files)} JSONL files.")
| 12 |
| 13 |      for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
|
|
| 21 |
print(f"Processing file: {filename}")
|
| 22 |
|
| 23 |
try:
|
| 24 |
-
with open(file_path,
|
| 25 |
for line_number, line in enumerate(infile, start=1):
|
| 26 |
line = line.strip()
|
| 27 |
if not line:
|
| 28 |
continue
|
| 29 |
try:
|
| 30 |
json_obj = json.loads(line)
|
| 31 |
-
context = json_obj.get(
|
| 32 |
if context and context not in unique_contexts_dict:
|
| 33 |
unique_contexts_dict[context] = None
|
| 34 |
except json.JSONDecodeError as e:
|
| 35 |
-
print(
|
|
|
|
|
|
|
| 36 |
except FileNotFoundError:
|
| 37 |
print(f"File not found: {filename}")
|
| 38 |
continue
|
|
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
|
|
| 41 |
continue
|
| 42 |
|
| 43 |
unique_contexts_list = list(unique_contexts_dict.keys())
|
| 44 |
-
print(
|
|
|
|
|
|
|
| 45 |
|
| 46 |
try:
|
| 47 |
-
with open(output_path,
|
| 48 |
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
| 49 |
print(f"Unique `context` entries have been saved to: {output_filename}")
|
| 50 |
except Exception as e:
|
|
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
|
|
| 55 |
|
| 56 |
if __name__ == "__main__":
|
| 57 |
parser = argparse.ArgumentParser()
|
| 58 |
-
parser.add_argument(
|
| 59 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 60 |
|
| 61 |
args = parser.parse_args()
|
| 62 |
|
|
|
|
| 3 |
import glob
|
| 4 |
import argparse
|
| 5 |
|
|
|
|
| 6 |
|
| 7 |
+
def extract_unique_contexts(input_directory, output_directory):
|
| 8 |
os.makedirs(output_directory, exist_ok=True)
|
| 9 |
|
| 10 |
+
jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
|
| 11 |
print(f"Found {len(jsonl_files)} JSONL files.")
|
| 12 |
|
| 13 |
for file_path in jsonl_files:
|
|
|
|
| 21 |
print(f"Processing file: {filename}")
|
| 22 |
|
| 23 |
try:
|
| 24 |
+
with open(file_path, "r", encoding="utf-8") as infile:
|
| 25 |
for line_number, line in enumerate(infile, start=1):
|
| 26 |
line = line.strip()
|
| 27 |
if not line:
|
| 28 |
continue
|
| 29 |
try:
|
| 30 |
json_obj = json.loads(line)
|
| 31 |
+
context = json_obj.get("context")
|
| 32 |
if context and context not in unique_contexts_dict:
|
| 33 |
unique_contexts_dict[context] = None
|
| 34 |
except json.JSONDecodeError as e:
|
| 35 |
+
print(
|
| 36 |
+
f"JSON decoding error in file {filename} at line {line_number}: {e}"
|
| 37 |
+
)
|
| 38 |
except FileNotFoundError:
|
| 39 |
print(f"File not found: {filename}")
|
| 40 |
continue
|
|
|
|
| 43 |
continue
|
| 44 |
|
| 45 |
unique_contexts_list = list(unique_contexts_dict.keys())
|
| 46 |
+
print(
|
| 47 |
+
f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
|
| 48 |
+
)
|
| 49 |
|
| 50 |
try:
|
| 51 |
+
with open(output_path, "w", encoding="utf-8") as outfile:
|
| 52 |
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
| 53 |
print(f"Unique `context` entries have been saved to: {output_filename}")
|
| 54 |
except Exception as e:
|
|
|
|
| 59 |
|
| 60 |
if __name__ == "__main__":
|
| 61 |
parser = argparse.ArgumentParser()
|
| 62 |
+
parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"-o", "--output_dir", type=str, default="../datasets/unique_contexts"
|
| 65 |
+
)
|
| 66 |
|
| 67 |
args = parser.parse_args()
|
| 68 |
|
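Step_0.py only gained formatting here, but for orientation: extract_unique_contexts walks every *.jsonl file in the input directory, de-duplicates the `context` field, and writes one JSON list per input file. A hedged sketch of calling it directly instead of via argparse (the paths are the script's own defaults):

```python
from Step_0 import extract_unique_contexts

# Same defaults as the argparse flags -i/--input_dir and -o/--output_dir.
extract_unique_contexts(
    input_directory="../datasets",
    output_directory="../datasets/unique_contexts",
)
```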
reproduce/Step_1.py
CHANGED
|
@@ -4,10 +4,11 @@ import time
| 4 |
| 5 |  from lightrag import LightRAG
| 6 |
| 7 |  def insert_text(rag, file_path):
| 8 | -    with open(file_path, mode=
| 9 |      unique_contexts = json.load(f)
| 10 | -
| 11 |      retries = 0
| 12 |      max_retries = 3
| 13 |      while retries < max_retries:

@@ -21,6 +22,7 @@ def insert_text(rag, file_path):

| 21 |      if retries == max_retries:
| 22 |          print("Insertion failed after exceeding the maximum number of retries")
| 23 |
| 24 |  cls = "agriculture"
| 25 |  WORKING_DIR = "../{cls}"
| 26 |

@@ -29,4 +31,4 @@ if not os.path.exists(WORKING_DIR):

| 29 |
| 30 |  rag = LightRAG(working_dir=WORKING_DIR)
| 31 |
| 32 | -insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
|
|
|
| 4 |
|
| 5 |
from lightrag import LightRAG
|
| 6 |
|
| 7 |
+
|
| 8 |
def insert_text(rag, file_path):
|
| 9 |
+
with open(file_path, mode="r") as f:
|
| 10 |
unique_contexts = json.load(f)
|
| 11 |
+
|
| 12 |
retries = 0
|
| 13 |
max_retries = 3
|
| 14 |
while retries < max_retries:
|
|
|
|
| 22 |
if retries == max_retries:
|
| 23 |
print("Insertion failed after exceeding the maximum number of retries")
|
| 24 |
|
| 25 |
+
|
| 26 |
cls = "agriculture"
|
| 27 |
WORKING_DIR = "../{cls}"
|
| 28 |
|
|
|
|
| 31 |
|
| 32 |
rag = LightRAG(working_dir=WORKING_DIR)
|
| 33 |
|
| 34 |
+
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
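insert_text in Step_1.py loads the de-duplicated contexts and retries the insertion up to max_retries times; the body of the while loop is collapsed in this view. A hedged reconstruction of the loop is sketched below — the rag.insert call, the broad except, and the back-off sleep are assumptions about the hidden lines, not text from the diff.

```python
import json
import time


def insert_text(rag, file_path, max_retries=3):
    with open(file_path, mode="r") as f:
        unique_contexts = json.load(f)

    retries = 0
    while retries < max_retries:
        try:
            rag.insert(unique_contexts)  # assumed: bulk-insert the context list
            break
        except Exception as e:  # assumed: retry on any insertion failure
            retries += 1
            print(f"Insertion failed, retry ({retries}/{max_retries}): {e}")
            time.sleep(10)  # assumed back-off; the real interval is not shown
    if retries == max_retries:
        print("Insertion failed after exceeding the maximum number of retries")
```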
reproduce/Step_1_openai_compatible.py
CHANGED
@@ -7,6 +7,7 @@ from lightrag import LightRAG
 from lightrag.utils import EmbeddingFunc
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 
+
 ## For Upstage API
 # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
 async def llm_model_func(
@@ -19,22 +20,26 @@ async def llm_model_func(
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )
 
+
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
+
+
 ## /For Upstage API
 
+
 def insert_text(rag, file_path):
-    with open(file_path, mode=
+    with open(file_path, mode="r") as f:
         unique_contexts = json.load(f)
-
+
     retries = 0
     max_retries = 3
     while retries < max_retries:
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
     if retries == max_retries:
         print("Insertion failed after exceeding the maximum number of retries")
 
+
 cls = "mix"
 WORKING_DIR = f"../{cls}"
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
-rag = LightRAG(
-
-
-
-
-
-
-)
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=4096, max_token_size=8192, func=embedding_func
+    ),
+)
 
 insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
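
Since the comment in this script asks to double-check `embedding_dim=4096`, one quick sanity check (a hedged sketch, assuming `UPSTAGE_API_KEY` is set and run in the same module so `embedding_func` from above is in scope) is to embed a single string and compare the returned width against the value passed to `EmbeddingFunc`:

```python
import asyncio
import numpy as np

async def check_embedding_dim():
    # Calls the embedding_func defined above; 4096 is the dimension this script assumes.
    vectors = await embedding_func(["dimension check"])
    print(np.asarray(vectors).shape)  # expected: (1, 4096)

asyncio.run(check_embedding_dim())
```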
reproduce/Step_2.py
CHANGED
@@ -1,8 +1,8 @@
-import os
 import json
 from openai import OpenAI
 from transformers import GPT2Tokenizer
 
+
 def openai_complete_if_cache(
     model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -19,24 +19,26 @@
     )
     return response.choices[0].message.content
 
-
+
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
 
 def get_summary(context, tot_tokens=2000):
     tokens = tokenizer.tokenize(context)
     half_tokens = tot_tokens // 2
 
-    start_tokens = tokens[1000:1000 + half_tokens]
-    end_tokens = tokens[-(1000 + half_tokens):1000]
+    start_tokens = tokens[1000 : 1000 + half_tokens]
+    end_tokens = tokens[-(1000 + half_tokens) : 1000]
 
     summary_tokens = start_tokens + end_tokens
     summary = tokenizer.convert_tokens_to_string(summary_tokens)
-
+
     return summary
 
 
-clses = [
+clses = ["agriculture"]
 for cls in clses:
-    with open(f
+    with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
         unique_contexts = json.load(f)
 
     summaries = [get_summary(context) for context in unique_contexts]
@@ -67,10 +69,10 @@ for cls in clses:
     ...
     """
 
-    result = openai_complete_if_cache(model=
+    result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
 
     file_path = f"../datasets/questions/{cls}_questions.txt"
     with open(file_path, "w") as file:
        file.write(result)
 
-    print(f"{cls}_questions written to {file_path}")
+    print(f"{cls}_questions written to {file_path}")
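
The reformatted slices in `get_summary` only add spaces around the colons, so behaviour is unchanged. A tiny worked example (a hypothetical 2,500-token document represented as a list of integers) shows what the two windows actually select:

```python
# Stand-in for a 2,500-token document.
tokens = list(range(2500))
tot_tokens = 2000
half_tokens = tot_tokens // 2

start_tokens = tokens[1000 : 1000 + half_tokens]   # elements 1000..1999
end_tokens = tokens[-(1000 + half_tokens) : 1000]  # -2000 resolves to index 500, so elements 500..999

assert len(start_tokens) == 1000 and len(end_tokens) == 500
```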
reproduce/Step_3.py
CHANGED
@@ -4,16 +4,18 @@ import asyncio
 from lightrag import LightRAG, QueryParam
 from tqdm import tqdm
 
+
 def extract_queries(file_path):
-    with open(file_path,
+    with open(file_path, "r") as f:
         data = f.read()
-
-    data = data.replace('**', '')
 
-
+    data = data.replace("**", "")
+
+    queries = re.findall(r"- Question \d+: (.+)", data)
 
     return queries
 
+
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     asyncio.set_event_loop(loop)
     return loop
 
-
+
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()
 
-    with open(output_file,
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True
 
         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )
 
             if result:
                 if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
 
         result_file.write("\n]")
 
+
 if __name__ == "__main__":
     cls = "agriculture"
     mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
     query_param = QueryParam(mode=mode)
 
     queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
-    run_queries_and_save_to_json(
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
+    )
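
For reference, `extract_queries` expects the questions file produced in Step_2 to contain lines of the form `- Question N: …`. A short illustration (with made-up questions) of what the regex returns:

```python
import re

data = """\
- Question 1: What crop rotations improve soil nitrogen?
- Question 2: How does drip irrigation affect yield stability?
"""
data = data.replace("**", "")
queries = re.findall(r"- Question \d+: (.+)", data)
print(queries)
# ['What crop rotations improve soil nitrogen?',
#  'How does drip irrigation affect yield stability?']
```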
reproduce/Step_3_openai_compatible.py
CHANGED
@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 
+
 ## For Upstage API
 # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
 async def llm_model_func(
@@ -20,28 +21,33 @@ async def llm_model_func(
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )
 
+
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
+
+
 ## /For Upstage API
 
+
 def extract_queries(file_path):
-    with open(file_path,
+    with open(file_path, "r") as f:
         data = f.read()
-
-    data = data.replace('**', '')
 
-
+    data = data.replace("**", "")
+
+    queries = re.findall(r"- Question \d+: (.+)", data)
 
     return queries
 
+
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     asyncio.set_event_loop(loop)
     return loop
 
-
+
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()
 
-    with open(output_file,
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True
 
         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )
 
             if result:
                 if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
 
         result_file.write("\n]")
 
+
 if __name__ == "__main__":
     cls = "mix"
     mode = "hybrid"
     WORKING_DIR = f"../{cls}"
 
     rag = LightRAG(working_dir=WORKING_DIR)
-    rag = LightRAG(
-
-
-
-
-
-
-    )
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=4096, max_token_size=8192, func=embedding_func
+        ),
+    )
     query_param = QueryParam(mode=mode)
 
-    base_dir=
+    base_dir = "../datasets/questions"
     queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
-    run_queries_and_save_to_json(
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
+    )
requirements.txt
CHANGED
@@ -1,13 +1,13 @@
+accelerate
 aioboto3
-openai
-tiktoken
-networkx
 graspologic
-nano-vectordb
 hnswlib
-
+nano-vectordb
+networkx
+ollama
+openai
 tenacity
-
+tiktoken
 torch
-
-
+transformers
+xxhash