Initial code release
Browse filesCode for memory-retrieval experiments.
This view is limited to 50 files because it contains too many changes. See raw diff
- README.md +163 -0
- baselines/MemoChat/LICENSE +21 -0
- baselines/MemoChat/README.md +47 -0
- baselines/MemoChat/code/codes/api/gpt_2k.py +90 -0
- baselines/MemoChat/code/codes/api/gpt_memochat.py +189 -0
- baselines/MemoChat/code/codes/api/llm_judge.py +96 -0
- baselines/MemoChat/code/codes/eval/eval_instruction_tuning_tasks.py +169 -0
- baselines/MemoChat/code/codes/eval/get_model_infer_memochat.py +291 -0
- baselines/MemoChat/code/codes/eval/get_model_infer_simple.py +150 -0
- baselines/MemoChat/code/codes/train/data_preprocess.py +118 -0
- baselines/MemoChat/code/codes/train/train.py +150 -0
- baselines/MemoChat/code/configs/ds_config_13b.json +53 -0
- baselines/MemoChat/code/configs/ds_config_33b.json +57 -0
- baselines/MemoChat/code/configs/ds_config_3b.json +39 -0
- baselines/MemoChat/code/configs/ds_config_7b.json +49 -0
- baselines/MemoChat/code/scripts/llm_judge.sh +35 -0
- baselines/MemoChat/code/scripts/memochat.sh +34 -0
- baselines/MemoChat/code/scripts/memochat_gpt.sh +18 -0
- baselines/MemoChat/code/scripts/tuning.sh +110 -0
- baselines/MemoChat/core_requirement.txt +13 -0
- baselines/MemoChat/run_memochat_baseline.py +634 -0
- baselines/raptor/LICENSE.txt +21 -0
- baselines/raptor/README.md +204 -0
- baselines/raptor/raptor/EmbeddingModels.py +37 -0
- baselines/raptor/raptor/FaissRetriever.py +201 -0
- baselines/raptor/raptor/QAModels.py +185 -0
- baselines/raptor/raptor/RetrievalAugmentation.py +306 -0
- baselines/raptor/raptor/Retrievers.py +8 -0
- baselines/raptor/raptor/SummarizationModels.py +74 -0
- baselines/raptor/raptor/__init__.py +16 -0
- baselines/raptor/raptor/cluster_tree_builder.py +151 -0
- baselines/raptor/raptor/cluster_utils.py +185 -0
- baselines/raptor/raptor/tree_builder.py +369 -0
- baselines/raptor/raptor/tree_retriever.py +327 -0
- baselines/raptor/raptor/tree_structures.py +28 -0
- baselines/raptor/raptor/utils.py +208 -0
- baselines/raptor/requirements.txt +11 -0
- baselines/raptor/run_raptor_baseline.py +511 -0
- baselines/read-agent/read_agent_demo.ipynb +976 -0
- baselines/read-agent/run_readagent_baseline.py +424 -0
- evaluate_qa.py +916 -0
- main.py +1717 -0
- memory/__init__.py +2 -0
- memory/episodic_store.py +62 -0
- memory/semantic_store.py +87 -0
- model_zoo.py +31 -0
- prompts/agentic_retrieval_prompt.txt +226 -0
- prompts/agentic_retrieval_prompt_wo_profile.txt +203 -0
- prompts/keyword_search_prompt.txt +31 -0
- prompts/read_and_extract_prompt.txt +176 -0
README.md
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Long-Term Memory Retrieval Benchmark
|
| 2 |
+
|
| 3 |
+
Code release for the experiments described in the accompanying paper:
|
| 4 |
+
- **Hierarchical memory** organization (User Profile / Semantic / Episodic).
|
| 5 |
+
- **Plan-Act-Read agentic retrieval** that interleaves keyword, time-filter,
|
| 6 |
+
and embedding search.
|
| 7 |
+
- **Flat / dense / oracle baselines** for comparison.
|
| 8 |
+
|
| 9 |
+
## Repository layout
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
.
|
| 13 |
+
├── main.py # End-to-end QA pipeline (agent, embed, keyword modes)
|
| 14 |
+
├── evaluate_qa.py # Atomic-rubric QA evaluator (strict + partial)
|
| 15 |
+
├── model_zoo.py # Model registry
|
| 16 |
+
├── prompts/ # Prompt templates
|
| 17 |
+
│ ├── agentic_retrieval_prompt.txt
|
| 18 |
+
│ ├── agentic_retrieval_prompt_wo_profile.txt
|
| 19 |
+
│ ├── keyword_search_prompt.txt
|
| 20 |
+
│ └── read_and_extract_prompt.txt
|
| 21 |
+
├── memory/ # Episodic + semantic memory stores
|
| 22 |
+
├── baselines/
|
| 23 |
+
│ ├── MemoChat/ # MemoChat baseline (upstream code + our wrapper)
|
| 24 |
+
│ ├── raptor/ # RAPTOR baseline (upstream code + our wrapper)
|
| 25 |
+
│ └── read-agent/ # ReadAgent baseline wrapper
|
| 26 |
+
├── scripts/
|
| 27 |
+
│ ├── build_retrieval_cache.py # Pre-compute GTE-7B embeddings for the corpus
|
| 28 |
+
│ ├── make_v5_shards.py # Deterministic shard split by question_id
|
| 29 |
+
│ ├── merge_jsonl_by_dataset_order.py
|
| 30 |
+
│ ├── run_oracle_qa.py # Gold-session-only upper bound
|
| 31 |
+
│ ├── plot_main_results.py
|
| 32 |
+
│ ├── llm_judge_agreement.py
|
| 33 |
+
│ └── slurm/
|
| 34 |
+
│ ├── example_dense_retrieval.slurm
|
| 35 |
+
│ └── example_agentic_retrieval.slurm
|
| 36 |
+
└── requirements.txt
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
The benchmark dataset (`evolv_mem_v5.json`) is released separately; place it
|
| 40 |
+
under `dataset/` along with the supporting files referenced by `main.py`
|
| 41 |
+
(`all_sessions.json`, `all_session_summary.json`, etc.).
|
| 42 |
+
|
| 43 |
+
## Setup
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
python -m venv .venv && source .venv/bin/activate
|
| 47 |
+
pip install -r requirements.txt
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### API keys
|
| 51 |
+
|
| 52 |
+
The pipeline calls LLMs through three optional providers; set whichever you
|
| 53 |
+
plan to use:
|
| 54 |
+
|
| 55 |
+
| Provider | Env var | Flag |
|
| 56 |
+
|------------------------------------------------|----------------------|--------------|
|
| 57 |
+
| OpenAI-compatible inference API | `NV_API_KEY` | `--nvidia` |
|
| 58 |
+
| OpenAI-compatible LiteLLM proxy | `LITELLM_API_KEY` | `--tritonai` |
|
| 59 |
+
| Direct Anthropic API | `ANTHROPIC_API_KEY` | (default) |
|
| 60 |
+
| Azure OpenAI | `AZURE_OPENAI_KEY` | (default) |
|
| 61 |
+
|
| 62 |
+
Each `--<flag>` selects which client the pipeline uses; entries in
|
| 63 |
+
`model_zoo.py` are tagged accordingly.
|
| 64 |
+
|
| 65 |
+
## Quick start
|
| 66 |
+
|
| 67 |
+
### 1. Build the per-question retrieval cache (one-time)
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
python scripts/build_retrieval_cache.py \
|
| 71 |
+
--dataset dataset/evolv_mem_v5.json \
|
| 72 |
+
--all_sessions dataset/all_sessions.json \
|
| 73 |
+
--out_dir response_cache/retrieval/
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### 2. Shard the dataset for parallel runs
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
python scripts/make_v5_shards.py \
|
| 80 |
+
--dataset dataset/evolv_mem_v5.json \
|
| 81 |
+
--ret_cache_jsonl response_cache/retrieval/flat-gte/v5_retrievallog_turn_flat-gte \
|
| 82 |
+
--out_dir output/shards/v5_run_nchunks10/ \
|
| 83 |
+
--num_shards 8
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### 3. Run the QA pipeline
|
| 87 |
+
|
| 88 |
+
Flat dense retrieval @ top-k=20 (single shard, e.g. for smoke testing):
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
export ret_cache="output/shards/v5_run_nchunks10/ret_cache/shard_00.jsonl"
|
| 92 |
+
python main.py \
|
| 93 |
+
--in_file output/shards/v5_run_nchunks10/dataset/shard_00.json \
|
| 94 |
+
--out_file output/shards/v5_run_nchunks10/dense_gte_topk20/part_00.jsonl \
|
| 95 |
+
--model_name gpt-5.5 \
|
| 96 |
+
--top_k 20 \
|
| 97 |
+
--n_chunks 10 \
|
| 98 |
+
--nvidia \
|
| 99 |
+
--all_sessions_file dataset/all_sessions.json \
|
| 100 |
+
--no_semantic \
|
| 101 |
+
--mode embed
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Agentic retrieval over hierarchical memory:
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
python main.py \
|
| 108 |
+
--in_file output/shards/v5_run_nchunks10/dataset/shard_00.json \
|
| 109 |
+
--out_file output/shards/v5_run_nchunks10/agentic_hier/part_00.jsonl \
|
| 110 |
+
--model_name gpt-5.5 \
|
| 111 |
+
--top_k 20 \
|
| 112 |
+
--n_chunks 10 \
|
| 113 |
+
--nvidia \
|
| 114 |
+
--all_sessions_file dataset/all_sessions.json \
|
| 115 |
+
--hier_v2 --hier_union \
|
| 116 |
+
--mode agent
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
To launch the full 8-shard parallel sweep on a SLURM cluster, edit and submit
|
| 120 |
+
`scripts/slurm/example_dense_retrieval.slurm` or
|
| 121 |
+
`scripts/slurm/example_agentic_retrieval.slurm`.
|
| 122 |
+
|
| 123 |
+
### 4. Merge shards and evaluate
|
| 124 |
+
|
| 125 |
+
```bash
|
| 126 |
+
python scripts/merge_jsonl_by_dataset_order.py \
|
| 127 |
+
--dataset dataset/evolv_mem_v5.json \
|
| 128 |
+
--parts_glob "output/shards/v5_run_nchunks10/dense_gte_topk20/part_*.jsonl" \
|
| 129 |
+
--out_file output/v5_run_dense_gte_topk20.jsonl
|
| 130 |
+
|
| 131 |
+
python evaluate_qa.py \
|
| 132 |
+
--hyp_file output/v5_run_dense_gte_topk20.jsonl \
|
| 133 |
+
--ref_file dataset/evolv_mem_v5.json \
|
| 134 |
+
--eval_model_name gpt-5.2 \
|
| 135 |
+
--eval_mode both \
|
| 136 |
+
--nvidia
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
The evaluator caches an atomic-rubric per question
|
| 140 |
+
(`<dataset>.atomic-v1.rubric.json`) so subsequent runs reuse it.
|
| 141 |
+
|
| 142 |
+
## Pipeline modes
|
| 143 |
+
|
| 144 |
+
`main.py --mode` selects how a question is answered:
|
| 145 |
+
|
| 146 |
+
- `embed`: top-k flat dense retrieval (GTE 7B), then a single LLM call to answer.
|
| 147 |
+
- `keyword`: LLM-generated keywords + lexical matching, then answer.
|
| 148 |
+
- `agent`: Plan-Act-Read loop. Combines `--hier_v2` (semantic-summary stage) and
|
| 149 |
+
`--hier_union` (union with flat top-K) for the hierarchical-memory variant.
|
| 150 |
+
|
| 151 |
+
`--no_semantic` disables the semantic-summary memory layer (flat memory).
|
| 152 |
+
|
| 153 |
+
## Baselines
|
| 154 |
+
|
| 155 |
+
The three external baselines (MemoChat, RAPTOR, ReadAgent) live under
|
| 156 |
+
`baselines/` together with our thin wrappers
|
| 157 |
+
(`run_<baseline>_baseline.py`). Each baseline's upstream LICENSE is preserved.
|
| 158 |
+
|
| 159 |
+
## License
|
| 160 |
+
|
| 161 |
+
This repository is released under the license stated in the corresponding
|
| 162 |
+
LICENSE file (TBD prior to release). Upstream baselines retain their original
|
| 163 |
+
licenses.
|
baselines/MemoChat/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Lu junru
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
baselines/MemoChat/README.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MemoChat
|
| 2 |
+
MemoChat: Tuning LLMs to Use Memos for Consistent Long-Range Open-domain Conversation
|
| 3 |
+
|
| 4 |
+
## Environment
|
| 5 |
+
We provide [core_requirement.txt](core_requirement.txt) for your convenience.
|
| 6 |
+
|
| 7 |
+
## Model Weights
|
| 8 |
+
The initial models we used are [fastchat models (v1.3)](https://lmsys.org/blog/2023-03-30-vicuna/). Below are the model weights of our fine-tuned version. Our models are built upon Fastchat modles, thus we adopt same `cc-by-nc-sa-4.0` license.
|
| 9 |
+
|
| 10 |
+
| Name | Share Link |
|
| 11 |
+
| --- | --- |
|
| 12 |
+
| MemoChat-Fastchat-T5-3B | https://huggingface.co/Junrulu/MemoChat-Fastchat-T5-3B |
|
| 13 |
+
| MemoChat-Vicuna-7B | https://huggingface.co/Junrulu/MemoChat-Vicuna-7B |
|
| 14 |
+
| MemoChat-Vicuna-13B | https://huggingface.co/Junrulu/MemoChat-Vicuna-13B |
|
| 15 |
+
| MemoChat-Vicuna-33B | https://huggingface.co/Junrulu/MemoChat-Vicuna-33B |
|
| 16 |
+
|
| 17 |
+
## Workflow
|
| 18 |
+
`RootPath` is the absolute path of this repo. Download initial models and put them in [model](model) folder.
|
| 19 |
+
### Instruction Tuning
|
| 20 |
+
```
|
| 21 |
+
Run `bash code/scripts/tuning.sh RootPath`. Intermediate evaluation are included in this script as well.
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
### MemoChat Testing
|
| 25 |
+
```
|
| 26 |
+
Run `bash code/scripts/memochat.sh RootPath` for pipeline testing with fine-tuned models.
|
| 27 |
+
Run `bash code/scripts/memochat_gpt.sh RootPath` for pipeline testing with GPT3.5 API.
|
| 28 |
+
Run `bash code/scripts/llm_judge.sh RootPath` for GPT4 judge (openai api is required).
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### Our Results
|
| 32 |
+
We provide our prediction results [here](https://drive.google.com/file/d/1jGNhT3iPXEA8B2fXHZ2Einy1AMre-8xB/view?usp=sharing).
|
| 33 |
+
|
| 34 |
+
## Acknowledgement
|
| 35 |
+
We thank [Vicuna project](https://github.com/lm-sys/FastChat/tree/main) for their great work.
|
| 36 |
+
|
| 37 |
+
## Citation
|
| 38 |
+
```
|
| 39 |
+
@misc{lu2023memochat,
|
| 40 |
+
title={MemoChat: Tuning LLMs to Use Memos for Consistent Long-Range Open-Domain Conversation},
|
| 41 |
+
author={Junru Lu and Siyu An and Mingbao Lin and Gabriele Pergola and Yulan He and Di Yin and Xing Sun and Yunsheng Wu},
|
| 42 |
+
year={2023},
|
| 43 |
+
eprint={2308.08239},
|
| 44 |
+
archivePrefix={arXiv},
|
| 45 |
+
primaryClass={cs.CL}
|
| 46 |
+
}
|
| 47 |
+
```
|
baselines/MemoChat/code/codes/api/gpt_2k.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import openai
|
| 4 |
+
import time
|
| 5 |
+
import sys
|
| 6 |
+
import tiktoken
|
| 7 |
+
|
| 8 |
+
input_data = sys.argv[1]
|
| 9 |
+
openai_modelid = sys.argv[2]
|
| 10 |
+
openai.api_key = sys.argv[3]
|
| 11 |
+
output_path = sys.argv[4]
|
| 12 |
+
prompt_path = sys.argv[5]
|
| 13 |
+
encoding = tiktoken.encoding_for_model(openai_modelid)
|
| 14 |
+
|
| 15 |
+
q_pre = ""
|
| 16 |
+
qa_link = ""
|
| 17 |
+
MaxLen = 2048
|
| 18 |
+
TarLen = 512
|
| 19 |
+
TaskTarLen = {
|
| 20 |
+
"chatting_dialogsum": MaxLen,
|
| 21 |
+
"chatting_alpacagpt4": MaxLen,
|
| 22 |
+
"writing_topiocqa": TarLen // 2,
|
| 23 |
+
"writing_dialogsum": TarLen,
|
| 24 |
+
"retrieval_dialogsum": 32,
|
| 25 |
+
"retrieval_topiocqa": 32
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
prompts = json.load(open(prompt_path, "r"))
|
| 29 |
+
|
| 30 |
+
def normalize_chatting_outputs(model_outputs):
|
| 31 |
+
def white_space_fix(text):
|
| 32 |
+
lines = text.split("\n")
|
| 33 |
+
result = []
|
| 34 |
+
for line in lines:
|
| 35 |
+
result.append(' '.join(line.split()))
|
| 36 |
+
output = '\n'.join(result)
|
| 37 |
+
return output
|
| 38 |
+
return white_space_fix(model_outputs)
|
| 39 |
+
|
| 40 |
+
def gen_model_output(input_qs, task_type):
|
| 41 |
+
input_qs_token_l = len(encoding.encode(input_qs)) # token num
|
| 42 |
+
input_qs_word_l = len(input_qs.split(" ")) # word num
|
| 43 |
+
qs_w_t_ratio = input_qs_word_l / input_qs_token_l
|
| 44 |
+
max_word_num = int((MaxLen - TarLen) * qs_w_t_ratio)
|
| 45 |
+
input_qs = " ".join(input_qs.split(" ")[-max_word_num:])
|
| 46 |
+
target_len = TaskTarLen[task_type]
|
| 47 |
+
messages = [{"role": "system", "content": input_qs}]
|
| 48 |
+
for _ in range(5):
|
| 49 |
+
try:
|
| 50 |
+
chat = openai.ChatCompletion.create(
|
| 51 |
+
model=openai_modelid, messages=messages, max_tokens=target_len, temperature=0.2
|
| 52 |
+
)
|
| 53 |
+
break
|
| 54 |
+
except:
|
| 55 |
+
time.sleep(5)
|
| 56 |
+
model_outputs = chat.choices[0].message.content
|
| 57 |
+
return model_outputs
|
| 58 |
+
|
| 59 |
+
def run_eval():
|
| 60 |
+
data = json.load(open(input_data, "r"))
|
| 61 |
+
output_data = []
|
| 62 |
+
for d in data:
|
| 63 |
+
print("=" * 20 + "start of question {}".format(d["id"]) + "=" * 20)
|
| 64 |
+
new_d = d
|
| 65 |
+
|
| 66 |
+
history = []
|
| 67 |
+
for l_i in range(len(new_d["conversations"])):
|
| 68 |
+
if l_i % 2 == 1:
|
| 69 |
+
bot_thinking = {"retrieval": "", "summarization": ""}
|
| 70 |
+
print("=" * 20 + "start of turn {}".format(l_i // 2 + 1) + "=" * 20)
|
| 71 |
+
user = "user: " + new_d["conversations"][l_i - 1]["value"]
|
| 72 |
+
|
| 73 |
+
system_insturction = prompts["chatting"]["system"]
|
| 74 |
+
task_instruction = prompts["chatting"]["instruction"]
|
| 75 |
+
task_case = "```\nRecent Dialogs:\n" + " ### ".join([hrd.replace("\n", " ") for hrd in history]) + "\n```\n\nUser Input:\n" + user + " ### bot: "
|
| 76 |
+
qs = system_insturction + task_case + task_instruction
|
| 77 |
+
print(qs + "\n\n")
|
| 78 |
+
outputs = gen_model_output(qs, "chatting_dialogsum")
|
| 79 |
+
outputs = normalize_chatting_outputs(outputs)
|
| 80 |
+
history += [user, "bot: " + outputs]
|
| 81 |
+
print("bot: " + outputs + "\n")
|
| 82 |
+
print("=" * 20 + "end of turn {}".format(l_i // 2 + 1) + "=" * 20)
|
| 83 |
+
new_d["conversations"][l_i]["thinking"] = json.dumps(bot_thinking)
|
| 84 |
+
new_d["conversations"][l_i]["value"] = outputs
|
| 85 |
+
|
| 86 |
+
output_data.append(new_d)
|
| 87 |
+
json.dump(output_data, open(output_path, "w"), indent=2)
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
run_eval()
|
baselines/MemoChat/code/codes/api/gpt_memochat.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import openai
|
| 4 |
+
import time
|
| 5 |
+
import sys
|
| 6 |
+
import tiktoken
|
| 7 |
+
from random import sample
|
| 8 |
+
|
| 9 |
+
input_data = sys.argv[1]
|
| 10 |
+
openai_modelid = sys.argv[2]
|
| 11 |
+
openai.api_key = sys.argv[3]
|
| 12 |
+
output_path = sys.argv[4]
|
| 13 |
+
prompt_path = sys.argv[5]
|
| 14 |
+
encoding = tiktoken.encoding_for_model(openai_modelid)
|
| 15 |
+
|
| 16 |
+
q_pre = ""
|
| 17 |
+
qa_link = ""
|
| 18 |
+
MaxLen = 2048
|
| 19 |
+
TarLen = 512
|
| 20 |
+
TaskTarLen = {
|
| 21 |
+
"chatting_dialogsum": MaxLen,
|
| 22 |
+
"chatting_alpacagpt4": MaxLen,
|
| 23 |
+
"writing_topiocqa": TarLen // 2,
|
| 24 |
+
"writing_dialogsum": TarLen,
|
| 25 |
+
"retrieval_dialogsum": 32,
|
| 26 |
+
"retrieval_topiocqa": 32
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
prompts = json.load(open(prompt_path, "r"))
|
| 30 |
+
|
| 31 |
+
def normalize_model_outputs(model_text):
|
| 32 |
+
extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", "")) for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
|
| 33 |
+
model_outputs = []
|
| 34 |
+
ti = 0
|
| 35 |
+
while ti + 7 < len(extracted_elements):
|
| 36 |
+
if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
|
| 37 |
+
try:
|
| 38 |
+
model_outputs.append({"topic": extracted_elements[ti + 1], "summary": extracted_elements[ti + 3], "start": int(extracted_elements[ti + 5]), "end": int(extracted_elements[ti + 7])})
|
| 39 |
+
except:
|
| 40 |
+
pass
|
| 41 |
+
ti += 1
|
| 42 |
+
return model_outputs
|
| 43 |
+
|
| 44 |
+
def normalize_chatting_outputs(model_outputs):
|
| 45 |
+
def white_space_fix(text):
|
| 46 |
+
lines = text.split("\n")
|
| 47 |
+
result = []
|
| 48 |
+
for line in lines:
|
| 49 |
+
result.append(' '.join(line.split()))
|
| 50 |
+
output = '\n'.join(result)
|
| 51 |
+
return output
|
| 52 |
+
return white_space_fix(model_outputs)
|
| 53 |
+
|
| 54 |
+
def gen_model_output(input_qs, task_type):
|
| 55 |
+
input_qs_token_l = len(encoding.encode(input_qs)) # token num
|
| 56 |
+
input_qs_word_l = len(input_qs.split(" ")) # word num
|
| 57 |
+
qs_w_t_ratio = input_qs_word_l / input_qs_token_l
|
| 58 |
+
max_word_num = int((MaxLen - TarLen) * qs_w_t_ratio)
|
| 59 |
+
input_qs = " ".join(input_qs.split(" ")[-max_word_num:])
|
| 60 |
+
target_len = TaskTarLen[task_type]
|
| 61 |
+
messages = [{"role": "system", "content": input_qs}]
|
| 62 |
+
for _ in range(5):
|
| 63 |
+
try:
|
| 64 |
+
chat = openai.ChatCompletion.create(
|
| 65 |
+
model=openai_modelid, messages=messages, max_tokens=target_len, temperature=0.2
|
| 66 |
+
)
|
| 67 |
+
break
|
| 68 |
+
except:
|
| 69 |
+
time.sleep(5)
|
| 70 |
+
model_outputs = chat.choices[0].message.content
|
| 71 |
+
return model_outputs
|
| 72 |
+
|
| 73 |
+
def run_summary(history, memo, bot_thinking):
|
| 74 |
+
system_insturction = prompts["writing_dialogsum"]["system"]
|
| 75 |
+
task_instruction = prompts["writing_dialogsum"]["instruction"]
|
| 76 |
+
history_log = "\n\n```\nTask Conversation:\n" + "\n".join(["(line {}) {}".format(h_i + 1, h.replace("\n", " ")) for h_i, h in enumerate(history["Recent Dialogs"][2:])])
|
| 77 |
+
qs = q_pre + system_insturction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + history_log + "\n```" + task_instruction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + qa_link
|
| 78 |
+
# print("-" * 20 + "summarizing" + "-" * 20)
|
| 79 |
+
# print(qs)
|
| 80 |
+
# print("-" * 20 + "summarizing" + "-" * 20)
|
| 81 |
+
sum_history = gen_model_output(qs, "writing_dialogsum")
|
| 82 |
+
sum_history = normalize_model_outputs(sum_history)
|
| 83 |
+
# print("-" * 20 + "summarization" + "-" * 20)
|
| 84 |
+
# print(sum_history)
|
| 85 |
+
# print("-" * 20 + "summarization" + "-" * 20)
|
| 86 |
+
for s in sum_history:
|
| 87 |
+
memo[s["topic"]] = memo.get(s["topic"], []) + [{"summary": s["summary"], "dialogs": history["Recent Dialogs"][2:][(s["start"] - 1):s["end"]]}]
|
| 88 |
+
if len(sum_history) == 0:
|
| 89 |
+
si_0, si_1 = sample(list(range(len(history["Recent Dialogs"][2:]))), 2)
|
| 90 |
+
memo["NOTO"].append({"summary": "Partial dialogs about: {} or {}.".format(history["Recent Dialogs"][2:][si_0], history["Recent Dialogs"][2:][si_1]), "dialogs": history["Recent Dialogs"][2:]})
|
| 91 |
+
history["Recent Dialogs"] = history["Recent Dialogs"][-2:]
|
| 92 |
+
bot_thinking["summarization"] = {"input": qs, "output": sum_history}
|
| 93 |
+
return history, memo, bot_thinking
|
| 94 |
+
|
| 95 |
+
def run_retrieval(history, memo, bot_thinking):
|
| 96 |
+
topics = []
|
| 97 |
+
for k, v in memo.items():
|
| 98 |
+
for vv in v:
|
| 99 |
+
topics.append((k, vv["summary"], vv["dialogs"]))
|
| 100 |
+
system_insturction = prompts["retrieval"]["system"]
|
| 101 |
+
task_instruction = prompts["retrieval"]["instruction"]
|
| 102 |
+
task_case = "```\nQuery Sentence:\n" + history["User Input"][6:] + "\nTopic Options:\n" + \
|
| 103 |
+
"\n".join(["({}) {}".format(v_i + 1, v[0] + ". " + v[1]) for v_i, v in enumerate(topics)]) + "\n```"
|
| 104 |
+
qs = q_pre + system_insturction.replace("OPTION", str(len(topics))) + task_case + task_instruction.replace("OPTION", str(len(topics))) + qa_link
|
| 105 |
+
# print("-" * 20 + "retrieving" + "-" * 20)
|
| 106 |
+
# print(qs)
|
| 107 |
+
# print("-" * 20 + "retrieving" + "-" * 20)
|
| 108 |
+
outputs = gen_model_output(qs, "retrieval_dialogsum")
|
| 109 |
+
# print("-" * 20 + "retrieval" + "-" * 20)
|
| 110 |
+
# print(outputs)
|
| 111 |
+
# print("-" * 20 + "retrieval" + "-" * 20)
|
| 112 |
+
outputs = outputs.split("#")
|
| 113 |
+
chosen_topics = []
|
| 114 |
+
for output in outputs:
|
| 115 |
+
try:
|
| 116 |
+
index_ = int(output) - 1
|
| 117 |
+
except:
|
| 118 |
+
continue
|
| 119 |
+
if index_ < len(topics) and "NOTO" not in topics[index_]:
|
| 120 |
+
chosen_topics.append(topics[index_])
|
| 121 |
+
if len(chosen_topics) > 0:
|
| 122 |
+
history["Related Topics"] = [ct[0] for ct in chosen_topics]
|
| 123 |
+
history["Related Summaries"] = [ct[1] for ct in chosen_topics]
|
| 124 |
+
history["Related Dialogs"] = [" ### ".join(ct[2]) for ct in chosen_topics]
|
| 125 |
+
else:
|
| 126 |
+
history["Related Topics"] = []
|
| 127 |
+
history["Related Summaries"] = []
|
| 128 |
+
history["Related Dialogs"] = []
|
| 129 |
+
bot_thinking["retrieval"] = {"input": qs, "output": outputs}
|
| 130 |
+
return history, bot_thinking
|
| 131 |
+
|
| 132 |
+
def run_eval():
|
| 133 |
+
data = json.load(open(input_data, "r"))
|
| 134 |
+
output_data = []
|
| 135 |
+
for d in data:
|
| 136 |
+
print("=" * 20 + "start of question {}".format(d["id"]) + "=" * 20)
|
| 137 |
+
new_d = d
|
| 138 |
+
|
| 139 |
+
history = {
|
| 140 |
+
"Recent Dialogs": ["user: Hi!", "bot: Hi! How can I help you today?"],
|
| 141 |
+
"Related Topics": [],
|
| 142 |
+
"Related Summaries": [],
|
| 143 |
+
"Related Dialogs": [],
|
| 144 |
+
"User Input": "",
|
| 145 |
+
}
|
| 146 |
+
memo = {
|
| 147 |
+
"NOTO": [{"summary": "None of the others.", "dialogs": []}]
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
for l_i in range(len(new_d["conversations"])):
|
| 151 |
+
if l_i % 2 == 1:
|
| 152 |
+
bot_thinking = {"retrieval": "", "summarization": ""}
|
| 153 |
+
print("=" * 20 + "start of turn {}".format(l_i // 2 + 1) + "=" * 20)
|
| 154 |
+
user = "user: " + new_d["conversations"][l_i - 1]["value"]
|
| 155 |
+
print(user + "\n\n")
|
| 156 |
+
|
| 157 |
+
# create summary if recent dialogs exceed threshold
|
| 158 |
+
if len(" ### ".join(history["Recent Dialogs"]).split(" ")) > (MaxLen // 2) or len(history["Recent Dialogs"]) >= 10:
|
| 159 |
+
history, memo, bot_thinking = run_summary(history, memo, bot_thinking)
|
| 160 |
+
|
| 161 |
+
# retrieve most related topics for every new user input
|
| 162 |
+
history["User Input"] = user
|
| 163 |
+
if len(memo.keys()) > 1:
|
| 164 |
+
history, bot_thinking = run_retrieval(history, memo, bot_thinking)
|
| 165 |
+
|
| 166 |
+
# generate bot response
|
| 167 |
+
system_insturction = prompts["chatting"]["system"]
|
| 168 |
+
task_instruction = prompts["chatting"]["instruction"]
|
| 169 |
+
task_case = "```\nRelated Evidences:\n" + "\n".join(["({}) {}".format(r_tsd_i + 1, {
|
| 170 |
+
"Related Topics": history["Related Topics"][r_tsd_i],
|
| 171 |
+
"Related Summaries": history["Related Summaries"][r_tsd_i],
|
| 172 |
+
"Related Dialogs": history["Related Dialogs"][r_tsd_i]
|
| 173 |
+
}) for r_tsd_i in range(len(history["Related Topics"]))]) + "\n\nRecent Dialogs:\n" + \
|
| 174 |
+
" ### ".join([hrd.replace("\n", " ") for hrd in history["Recent Dialogs"]]) + "\n```\n\nUser Input:\n" + history["User Input"] + " ### bot: "
|
| 175 |
+
qs = q_pre + system_insturction + task_case + task_instruction + qa_link
|
| 176 |
+
outputs = gen_model_output(qs, "chatting_dialogsum")
|
| 177 |
+
outputs = normalize_chatting_outputs(outputs)
|
| 178 |
+
history["Recent Dialogs"] += [user, "bot: " + outputs]
|
| 179 |
+
print("bot: " + outputs + "\n")
|
| 180 |
+
print("=" * 20 + "end of turn {}".format(l_i // 2 + 1) + "=" * 20)
|
| 181 |
+
# print("\n\n\n\n")
|
| 182 |
+
new_d["conversations"][l_i]["thinking"] = json.dumps(bot_thinking)
|
| 183 |
+
new_d["conversations"][l_i]["value"] = outputs
|
| 184 |
+
|
| 185 |
+
output_data.append(new_d)
|
| 186 |
+
json.dump(output_data, open(output_path, "w"), indent=2)
|
| 187 |
+
|
| 188 |
+
if __name__ == "__main__":
|
| 189 |
+
run_eval()
|
baselines/MemoChat/code/codes/api/llm_judge.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
import openai
|
| 6 |
+
import time
|
| 7 |
+
import tiktoken
|
| 8 |
+
|
| 9 |
+
input_data = sys.argv[1]
|
| 10 |
+
openai_modelid = sys.argv[2]
|
| 11 |
+
openai.api_key = sys.argv[3]
|
| 12 |
+
output_path = sys.argv[4]
|
| 13 |
+
prompt_path = sys.argv[5]
|
| 14 |
+
encoding = tiktoken.encoding_for_model(openai_modelid)
|
| 15 |
+
|
| 16 |
+
prompts = json.load(open(prompt_path, "r"))
|
| 17 |
+
judge_prompt_raw = prompts["judge"]["system"]
|
| 18 |
+
|
| 19 |
+
def gen_model_output(input_qs):
|
| 20 |
+
input_qs_token_l = len(encoding.encode(input_qs)) # token num
|
| 21 |
+
input_qs_word_l = len(input_qs.split(" ")) # word num
|
| 22 |
+
qs_w_t_ratio = input_qs_word_l / input_qs_token_l
|
| 23 |
+
max_word_num = int(4096 * qs_w_t_ratio)
|
| 24 |
+
input_qs = " ".join(input_qs.split(" ")[-max_word_num:])
|
| 25 |
+
messages = [{"role": "system", "content": input_qs}]
|
| 26 |
+
chat = None
|
| 27 |
+
for _ in range(5):
|
| 28 |
+
try:
|
| 29 |
+
chat = openai.ChatCompletion.create(
|
| 30 |
+
model=openai_modelid, messages=messages
|
| 31 |
+
)
|
| 32 |
+
break
|
| 33 |
+
except:
|
| 34 |
+
time.sleep(5)
|
| 35 |
+
if chat is None:
|
| 36 |
+
return "Cannot generate output."
|
| 37 |
+
model_outputs = chat.choices[0].message.content
|
| 38 |
+
return model_outputs
|
| 39 |
+
|
| 40 |
+
data = json.load(open(input_data, "r"))
|
| 41 |
+
|
| 42 |
+
# do llm judge
|
| 43 |
+
output_ratings = []
|
| 44 |
+
for d in data:
|
| 45 |
+
print("=" * 20 + "Processing: " + d["id"] + "=" * 20)
|
| 46 |
+
judge_conversation = []
|
| 47 |
+
d_conversations = d['conversations']
|
| 48 |
+
last_q = d_conversations[-2]
|
| 49 |
+
turn_infos = last_q["turn-info"].split("-")
|
| 50 |
+
r_turns = [turn_infos[0] + "-" + turn_infos[1]]
|
| 51 |
+
if len(turn_infos) == 5:
|
| 52 |
+
r_turns.append(turn_infos[2] + "-" + turn_infos[3])
|
| 53 |
+
for l_i in range(len(d_conversations) // 2 - 1):
|
| 54 |
+
if d_conversations[l_i * 2]["turn-info"][:-2] in r_turns:
|
| 55 |
+
judge_conversation.append("user: " + d_conversations[l_i * 2]["value"])
|
| 56 |
+
judge_conversation.append("bot: " + d_conversations[l_i * 2 + 1]["value"])
|
| 57 |
+
judge_prompt = judge_prompt_raw.replace("RCH_0", "\n".join(judge_conversation)).replace("UQ_1", "user: " + last_q["value"]).replace("BR_2", "bot: " + d_conversations[-1]["value"])
|
| 58 |
+
print(judge_prompt)
|
| 59 |
+
print('-' * 20)
|
| 60 |
+
outputs = gen_model_output(judge_prompt)
|
| 61 |
+
print(outputs)
|
| 62 |
+
print("=" * 20 + "Processed: " + d["id"] + "=" * 20)
|
| 63 |
+
match = re.search(r'\[\[(\d+)\]\]', outputs)
|
| 64 |
+
try:
|
| 65 |
+
rating = int(match.group(1))
|
| 66 |
+
except:
|
| 67 |
+
rating = None
|
| 68 |
+
output_ratings.append({
|
| 69 |
+
"id": d["id"],
|
| 70 |
+
"type": d["type"],
|
| 71 |
+
"judge_prompt": judge_prompt,
|
| 72 |
+
"evaluation": outputs,
|
| 73 |
+
"rating": rating
|
| 74 |
+
})
|
| 75 |
+
json.dump(output_ratings, open(output_path, "w"), indent=2)
|
| 76 |
+
|
| 77 |
+
# compute score
|
| 78 |
+
count = {
|
| 79 |
+
"continuation": [],
|
| 80 |
+
"retrospection": [],
|
| 81 |
+
"conjunction": []
|
| 82 |
+
}
|
| 83 |
+
for d in output_ratings:
|
| 84 |
+
if d["type"] == "continuation":
|
| 85 |
+
count["continuation"].append(d["rating"])
|
| 86 |
+
elif d["type"] == "retrospection":
|
| 87 |
+
count["retrospection"].append(d["rating"])
|
| 88 |
+
elif d["type"] == "conjunction":
|
| 89 |
+
count["conjunction"].append(d["rating"])
|
| 90 |
+
print("Retrospection Score: {}, Continuation Score: {}, Conjunction Score: {}, Overall Score: {} of file {}".format(
|
| 91 |
+
round(sum(count["retrospection"]) / len(count["retrospection"]), 2),
|
| 92 |
+
round(sum(count["continuation"]) / len(count["continuation"]), 2),
|
| 93 |
+
round(sum(count["conjunction"]) / len(count["conjunction"]), 2),
|
| 94 |
+
round(sum(count["continuation"] + count["retrospection"] + count["conjunction"]) / len(count["continuation"] + count["retrospection"] + count["conjunction"]), 2),
|
| 95 |
+
input_data
|
| 96 |
+
))
|
baselines/MemoChat/code/codes/eval/eval_instruction_tuning_tasks.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import string
|
| 4 |
+
import sys
|
| 5 |
+
import random
|
| 6 |
+
from argparse import ArgumentParser
|
| 7 |
+
from collections import Counter
|
| 8 |
+
|
| 9 |
+
from evaluate import load
|
| 10 |
+
bertscore = load("bertscore")
|
| 11 |
+
|
| 12 |
+
refer_file_path = sys.argv[1]
|
| 13 |
+
input_file_path = sys.argv[2]
|
| 14 |
+
|
| 15 |
+
conversations = open(refer_file_path, "r").readlines()
|
| 16 |
+
conversations_dict = {}
|
| 17 |
+
for conversation in conversations:
|
| 18 |
+
conv_l = json.loads(conversation.strip())
|
| 19 |
+
conversations_dict[conv_l["question_id"]] = (conv_l["text"], conv_l["answer"], conv_l["type"])
|
| 20 |
+
|
| 21 |
+
class Metrics():
|
| 22 |
+
def __init__(self):
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
def __normalize_text(self, s_text):
|
| 26 |
+
"""Lower text and remove punctuation, storys and extra whitespace."""
|
| 27 |
+
def remove_articles(text):
|
| 28 |
+
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
|
| 29 |
+
return re.sub(regex, ' ', text)
|
| 30 |
+
|
| 31 |
+
def white_space_fix(text):
|
| 32 |
+
return ' '.join(text.split())
|
| 33 |
+
|
| 34 |
+
def remove_punc(text):
|
| 35 |
+
exclude = set(string.punctuation)
|
| 36 |
+
return ''.join(ch for ch in text if ch not in exclude)
|
| 37 |
+
|
| 38 |
+
def lower(text):
|
| 39 |
+
return text.lower()
|
| 40 |
+
|
| 41 |
+
return white_space_fix(remove_articles(remove_punc(lower(s_text))))
|
| 42 |
+
|
| 43 |
+
def __normalize_model_outputs(self, model_text, type_category):
|
| 44 |
+
"""post process of memo writing outputs"""
|
| 45 |
+
extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", "")) for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
|
| 46 |
+
model_outputs = []
|
| 47 |
+
ti = 0
|
| 48 |
+
|
| 49 |
+
if "dialogsum" in type_category:
|
| 50 |
+
while ti + 7 < len(extracted_elements):
|
| 51 |
+
if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
|
| 52 |
+
try:
|
| 53 |
+
model_outputs.append({"topic": extracted_elements[ti + 1], "summary": extracted_elements[ti + 3], "start": int(extracted_elements[ti + 5]), "end": int(extracted_elements[ti + 7])})
|
| 54 |
+
except:
|
| 55 |
+
pass
|
| 56 |
+
ti += 1
|
| 57 |
+
else:
|
| 58 |
+
while ti + 5 < len(extracted_elements):
|
| 59 |
+
if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "start" and extracted_elements[ti + 4] == "end":
|
| 60 |
+
try:
|
| 61 |
+
model_outputs.append({"topic": extracted_elements[ti + 1], "start": int(extracted_elements[ti + 3]), "end": int(extracted_elements[ti + 5])})
|
| 62 |
+
except:
|
| 63 |
+
pass
|
| 64 |
+
ti += 1
|
| 65 |
+
|
| 66 |
+
return model_outputs
|
| 67 |
+
|
| 68 |
+
def __get_class_span_dict__(self, label, checkitem_k):
|
| 69 |
+
class_span = {}
|
| 70 |
+
for i in range(len(label)):
|
| 71 |
+
checkitem_i = self.__normalize_text(label[i][checkitem_k])
|
| 72 |
+
class_span[(label[i]['start'], label[i]['end'])] = class_span.get((label[i]['start'], label[i]['end']), []) + [checkitem_i]
|
| 73 |
+
return class_span
|
| 74 |
+
|
| 75 |
+
def __get_intersect_by_entity__(self, pred_class_span, label_class_span):
|
| 76 |
+
'''
|
| 77 |
+
return the count of correct entity
|
| 78 |
+
'''
|
| 79 |
+
cnt = 0
|
| 80 |
+
for label in label_class_span:
|
| 81 |
+
cnt += len(list(set(label_class_span[label]).intersection(set(pred_class_span.get(label,[])))))
|
| 82 |
+
return cnt
|
| 83 |
+
|
| 84 |
+
def __get_bertscore_by_entity__(self, pred_class_span, label_class_span):
|
| 85 |
+
'''
|
| 86 |
+
return the count of correct entity
|
| 87 |
+
'''
|
| 88 |
+
cnt = 0
|
| 89 |
+
for label in label_class_span:
|
| 90 |
+
if label in pred_class_span:
|
| 91 |
+
references = [label_class_span[label]]
|
| 92 |
+
prediction = [pred_class_span[label][0]]
|
| 93 |
+
result = bertscore.compute(predictions=prediction, references=references, model_type="microsoft/deberta-xlarge-mnli")["precision"][0]
|
| 94 |
+
cnt += result
|
| 95 |
+
return cnt
|
| 96 |
+
|
| 97 |
+
def __get_cnt__(self, label_class_span):
|
| 98 |
+
'''
|
| 99 |
+
return the count of entities
|
| 100 |
+
'''
|
| 101 |
+
cnt = 0
|
| 102 |
+
for label in label_class_span:
|
| 103 |
+
cnt += len(label_class_span[label])
|
| 104 |
+
# cnt += 1 # set as 1 if we have multiple references
|
| 105 |
+
return cnt
|
| 106 |
+
|
| 107 |
+
def metrics_by_entity_(self, pred, label, checkitem_k):
|
| 108 |
+
'''
|
| 109 |
+
return entity level count of total prediction, true labels, and correct prediction
|
| 110 |
+
'''
|
| 111 |
+
pred_class_span = self.__get_class_span_dict__(pred, checkitem_k)
|
| 112 |
+
label_class_span = self.__get_class_span_dict__(label, checkitem_k)
|
| 113 |
+
pred_cnt = self.__get_cnt__(pred_class_span)
|
| 114 |
+
label_cnt = self.__get_cnt__(label_class_span)
|
| 115 |
+
if checkitem_k == "topic":
|
| 116 |
+
correct_cnt = self.__get_intersect_by_entity__(pred_class_span, label_class_span)
|
| 117 |
+
elif checkitem_k == "summary":
|
| 118 |
+
correct_cnt = self.__get_bertscore_by_entity__(pred_class_span, label_class_span)
|
| 119 |
+
return pred_cnt, label_cnt, correct_cnt
|
| 120 |
+
|
| 121 |
+
def p_r_f1_by_entity(self, pc, lc, cc):
|
| 122 |
+
precision = cc / (pc + 1e-8)
|
| 123 |
+
recall = cc / (lc + 1e-8)
|
| 124 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-8)
|
| 125 |
+
return round(precision * 100, 2), round(recall * 100, 2), round(f1 * 100, 2)
|
| 126 |
+
|
| 127 |
+
def metrics_by_entity_files(self, pred_file, checkitem_k, type_key):
|
| 128 |
+
pred_cnt = 0
|
| 129 |
+
label_cnt = 0
|
| 130 |
+
correct_cnt = 0
|
| 131 |
+
for l_i, line in enumerate(open(pred_file, "r").readlines()):
|
| 132 |
+
eles = json.loads(line.strip())
|
| 133 |
+
|
| 134 |
+
if (type_key not in conversations_dict[eles["question_id"]][2]) or (conversations_dict[eles["question_id"]][2] == "writing_topiocqa" and checkitem_k == "summary"):
|
| 135 |
+
continue
|
| 136 |
+
if type_key == "writing":
|
| 137 |
+
model_text = self.__normalize_model_outputs(eles["text"], conversations_dict[eles["question_id"]][2])
|
| 138 |
+
label_i = json.loads(conversations_dict[eles["question_id"]][1])
|
| 139 |
+
elif type_key == "retrieval":
|
| 140 |
+
model_text = [{"topic": v, "start": 0, "end": 0} for v in set(eles["text"].split("#"))]
|
| 141 |
+
label_i = [{"topic": v, "start": 0, "end": 0} for v in set(conversations_dict[eles["question_id"]][1].split("#"))]
|
| 142 |
+
else:
|
| 143 |
+
model_text = [{"summary": eles["text"], "start": 0, "end": 0}]
|
| 144 |
+
label_i = [{"summary": conversations_dict[eles["question_id"]][1], "start": 0, "end": 0}]
|
| 145 |
+
|
| 146 |
+
p_cnt, l_cnt, c_cnt = self.metrics_by_entity_(model_text, label_i, checkitem_k)
|
| 147 |
+
p_i, r_i, f_i = self.p_r_f1_by_entity(p_cnt, l_cnt, c_cnt)
|
| 148 |
+
# if p_i + r_i + f_i != 0:
|
| 149 |
+
# print("Q ID: " + str(eles["question_id"]) + "\n")
|
| 150 |
+
# print(conversations_dict[eles["question_id"]][0] + "\n")
|
| 151 |
+
# # print("Raw Ouput: " + eles["text"] + "\n")
|
| 152 |
+
# print("Model: {}".format(model_text) + "\n")
|
| 153 |
+
# print("Refer: {}".format(label_i) + "\n")
|
| 154 |
+
# print("Case P/R/F1: {}%, {}%, {}%".format(p_i, r_i, f_i))
|
| 155 |
+
# print("=" * 20)
|
| 156 |
+
pred_cnt += p_cnt
|
| 157 |
+
label_cnt += l_cnt
|
| 158 |
+
correct_cnt += c_cnt
|
| 159 |
+
return self.p_r_f1_by_entity(pred_cnt, label_cnt, correct_cnt)
|
| 160 |
+
|
| 161 |
+
calculate_metrics = Metrics()
|
| 162 |
+
p_a, r_a, f1_a = calculate_metrics.metrics_by_entity_files(input_file_path, 'topic', 'writing') # both
|
| 163 |
+
print("Overall P/R/F1 of topic: {}%, {}%, {}%".format(p_a, r_a, f1_a))
|
| 164 |
+
p_b, r_b, f1_b = calculate_metrics.metrics_by_entity_files(input_file_path, 'summary', 'writing') # dialogsum
|
| 165 |
+
print("Overall P/R/F1 of summary: {}%, {}%, {}%".format(p_b, r_b, f1_b))
|
| 166 |
+
_, _, f1 = calculate_metrics.metrics_by_entity_files(input_file_path, "topic", "retrieval") # both
|
| 167 |
+
print("Retrival F1: {}%".format(f1))
|
| 168 |
+
p, _, _ = calculate_metrics.metrics_by_entity_files(input_file_path, "summary", "chatting") # dialogsum
|
| 169 |
+
print("Chatting similarity: {}%".format(p))
|
baselines/MemoChat/code/codes/eval/get_model_infer_memochat.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, AutoModelForSeq2SeqLM
|
| 3 |
+
from optimum.bettertransformer import BetterTransformer
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
import ray
|
| 9 |
+
import warnings
|
| 10 |
+
from random import sample
|
| 11 |
+
warnings.filterwarnings("ignore")
|
| 12 |
+
|
| 13 |
+
q_pre = "<s>\n"
|
| 14 |
+
qa_link = "\n"
|
| 15 |
+
MaxLen = 2048
|
| 16 |
+
TarLen = 512
|
| 17 |
+
TaskTarLen = {
|
| 18 |
+
"chatting_dialogsum": MaxLen,
|
| 19 |
+
"chatting_alpacagpt4": MaxLen,
|
| 20 |
+
"writing_topiocqa": TarLen // 2,
|
| 21 |
+
"writing_dialogsum": TarLen,
|
| 22 |
+
"retrieval_dialogsum": 32,
|
| 23 |
+
"retrieval_topiocqa": 32
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
def get_gpu_memory(num_gpus):
|
| 27 |
+
"""Get available memory for each GPU."""
|
| 28 |
+
gpu_memory = []
|
| 29 |
+
for gpu_id in range(num_gpus):
|
| 30 |
+
with torch.cuda.device(gpu_id):
|
| 31 |
+
device = torch.cuda.current_device()
|
| 32 |
+
gpu_properties = torch.cuda.get_device_properties(device)
|
| 33 |
+
total_memory = gpu_properties.total_memory / (1024**3)
|
| 34 |
+
allocated_memory = torch.cuda.memory_allocated() / (1024**3)
|
| 35 |
+
available_memory = total_memory - allocated_memory
|
| 36 |
+
gpu_memory.append(available_memory)
|
| 37 |
+
return gpu_memory
|
| 38 |
+
|
| 39 |
+
def normalize_model_outputs(model_text):
|
| 40 |
+
"""post processing function of memo writing task"""
|
| 41 |
+
extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", "")) for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
|
| 42 |
+
model_outputs = []
|
| 43 |
+
ti = 0
|
| 44 |
+
while ti + 7 < len(extracted_elements):
|
| 45 |
+
if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
|
| 46 |
+
try:
|
| 47 |
+
model_outputs.append({"topic": extracted_elements[ti + 1], "summary": extracted_elements[ti + 3], "start": int(extracted_elements[ti + 5]), "end": int(extracted_elements[ti + 7])})
|
| 48 |
+
except:
|
| 49 |
+
pass
|
| 50 |
+
ti += 1
|
| 51 |
+
return model_outputs
|
| 52 |
+
|
| 53 |
+
def normalize_chatting_outputs(model_outputs):
|
| 54 |
+
"""post processing function of chatting response"""
|
| 55 |
+
def white_space_fix(text):
|
| 56 |
+
lines = text.split("\n")
|
| 57 |
+
result = []
|
| 58 |
+
for line in lines:
|
| 59 |
+
result.append(' '.join(line.split()))
|
| 60 |
+
output = '\n'.join(result)
|
| 61 |
+
return output
|
| 62 |
+
return white_space_fix(model_outputs)
|
| 63 |
+
|
| 64 |
+
def gen_model_output(model_path, model, tokenizer, input_qs, local_check, task_type):
|
| 65 |
+
if local_check:
|
| 66 |
+
from faker import Faker
|
| 67 |
+
fake = Faker(locale="en")
|
| 68 |
+
return fake.text(2000)
|
| 69 |
+
if "writing" in task_type:
|
| 70 |
+
eos_token_ids = [tokenizer.eos_token_id, tokenizer.encode("]", add_special_tokens=False)[0]]
|
| 71 |
+
elif "retrieval" in task_type:
|
| 72 |
+
eos_token_ids = [tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[0], tokenizer.encode(" ", add_special_tokens=False)[0]]
|
| 73 |
+
else:
|
| 74 |
+
eos_token_ids = [tokenizer.eos_token_id]
|
| 75 |
+
if "t5" in model_path:
|
| 76 |
+
# t5 model may need larger repetition penalty value to help get generation stability
|
| 77 |
+
input_ids = tokenizer([input_qs], max_length=MaxLen, truncation=True, add_special_tokens=False).input_ids
|
| 78 |
+
target_len = TaskTarLen[task_type]
|
| 79 |
+
repetition_penalty_value = 1.0
|
| 80 |
+
else:
|
| 81 |
+
input_ids = tokenizer([input_qs], max_length=(MaxLen - TarLen), truncation=True, add_special_tokens=False).input_ids
|
| 82 |
+
target_len = min(len(input_ids[0]) + TaskTarLen[task_type], MaxLen)
|
| 83 |
+
repetition_penalty_value = 1.0
|
| 84 |
+
output_ids = model.generate(
|
| 85 |
+
torch.as_tensor(input_ids).cuda(),
|
| 86 |
+
do_sample=True,
|
| 87 |
+
temperature=0.2,
|
| 88 |
+
max_length=target_len,
|
| 89 |
+
eos_token_id=eos_token_ids,
|
| 90 |
+
repetition_penalty=repetition_penalty_value
|
| 91 |
+
)
|
| 92 |
+
if "t5" in model_path:
|
| 93 |
+
output_ids = output_ids[0]
|
| 94 |
+
else:
|
| 95 |
+
output_ids = output_ids[0][len(input_ids[0]):]
|
| 96 |
+
model_outputs = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
|
| 97 |
+
return model_outputs
|
| 98 |
+
|
| 99 |
+
def run_summary(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts):
|
| 100 |
+
"""We assume there's no too long input from user, e.g. over 1500 tokens"""
|
| 101 |
+
system_insturction = prompts["writing_dialogsum"]["system"]
|
| 102 |
+
task_instruction = prompts["writing_dialogsum"]["instruction"]
|
| 103 |
+
history_log = "\n\n```\nTask Conversation:\n" + "\n".join(["(line {}) {}".format(h_i + 1, h.replace("\n", " ")) for h_i, h in enumerate(history["Recent Dialogs"][2:])])
|
| 104 |
+
qs = q_pre + system_insturction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + history_log + "\n```" + task_instruction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + qa_link
|
| 105 |
+
print("#" * 20 + "summarizing" + "#" * 20)
|
| 106 |
+
print(qs)
|
| 107 |
+
print("#" * 20 + "summarizing" + "#" * 20)
|
| 108 |
+
sum_history = gen_model_output(model_path, model, tokenizer, qs, local_check, "writing_dialogsum")
|
| 109 |
+
sum_history = normalize_model_outputs(sum_history)
|
| 110 |
+
print("#" * 20 + "summarization" + "#" * 20)
|
| 111 |
+
print(sum_history)
|
| 112 |
+
print("#" * 20 + "summarization" + "#" * 20)
|
| 113 |
+
for s in sum_history:
|
| 114 |
+
memo[s["topic"]] = memo.get(s["topic"], []) + [{"summary": s["summary"], "dialogs": history["Recent Dialogs"][2:][(s["start"] - 1):s["end"]]}]
|
| 115 |
+
if local_check:
|
| 116 |
+
memo["test_topic{}".format(len(memo.keys()))] = [{"summary": "test_summary{}".format(len(memo.keys())), "dialogs": history["Recent Dialogs"][2:][2:4]}]
|
| 117 |
+
if len(sum_history) == 0:
|
| 118 |
+
si_0, si_1 = sample(list(range(len(history["Recent Dialogs"][2:]))), 2)
|
| 119 |
+
memo["NOTO"].append({"summary": "Partial dialogs about: {} or {}.".format(history["Recent Dialogs"][2:][si_0], history["Recent Dialogs"][2:][si_1]), "dialogs": history["Recent Dialogs"][2:]})
|
| 120 |
+
history["Recent Dialogs"] = history["Recent Dialogs"][-2:]
|
| 121 |
+
bot_thinking["summarization"] = {"input": qs, "output": sum_history}
|
| 122 |
+
return history, memo, bot_thinking
|
| 123 |
+
|
| 124 |
+
def run_retrieval(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts):
|
| 125 |
+
topics = []
|
| 126 |
+
for k, v in memo.items():
|
| 127 |
+
for vv in v:
|
| 128 |
+
topics.append((k, vv["summary"], vv["dialogs"]))
|
| 129 |
+
system_insturction = prompts["retrieval"]["system"]
|
| 130 |
+
task_instruction = prompts["retrieval"]["instruction"]
|
| 131 |
+
task_case = "```\nQuery Sentence:\n" + history["User Input"][6:] + "\nTopic Options:\n" + \
|
| 132 |
+
"\n".join(["({}) {}".format(v_i + 1, v[0] + ". " + v[1]) for v_i, v in enumerate(topics)]) + "\n```"
|
| 133 |
+
qs = q_pre + system_insturction.replace("OPTION", str(len(topics))) + task_case + task_instruction.replace("OPTION", str(len(topics))) + qa_link
|
| 134 |
+
print("#" * 20 + "retrieving" + "#" * 20)
|
| 135 |
+
print(qs)
|
| 136 |
+
print("#" * 20 + "retrieving" + "#" * 20)
|
| 137 |
+
outputs = gen_model_output(model_path, model, tokenizer, qs, local_check, "retrieval_dialogsum")
|
| 138 |
+
print("#" * 20 + "retrieval" + "#" * 20)
|
| 139 |
+
print(outputs)
|
| 140 |
+
print("#" * 20 + "retrieval" + "#" * 20)
|
| 141 |
+
outputs = outputs.split("#")
|
| 142 |
+
chosen_topics = []
|
| 143 |
+
for output in outputs:
|
| 144 |
+
try:
|
| 145 |
+
index_ = int(output) - 1
|
| 146 |
+
except:
|
| 147 |
+
continue
|
| 148 |
+
if index_ < len(topics) and "NOTO" not in topics[index_][0]:
|
| 149 |
+
chosen_topics.append(topics[index_])
|
| 150 |
+
if local_check:
|
| 151 |
+
chosen_topics = sample(topics, min(len(topics) - 1, 2))
|
| 152 |
+
if len(chosen_topics) > 0:
|
| 153 |
+
history["Related Topics"] = [ct[0] for ct in chosen_topics]
|
| 154 |
+
history["Related Summaries"] = [ct[1] for ct in chosen_topics]
|
| 155 |
+
history["Related Dialogs"] = [" ### ".join(ct[2]) for ct in chosen_topics]
|
| 156 |
+
else:
|
| 157 |
+
history["Related Topics"] = []
|
| 158 |
+
history["Related Summaries"] = []
|
| 159 |
+
history["Related Dialogs"] = []
|
| 160 |
+
bot_thinking["retrieval"] = {"input": qs, "output": outputs}
|
| 161 |
+
return history, bot_thinking
|
| 162 |
+
|
| 163 |
+
@torch.inference_mode()
|
| 164 |
+
def get_model_answers(model_path, num_gpus, local_check, load_in_8bit, ques_jsons, prompts):
|
| 165 |
+
model_path = os.path.expanduser(model_path)
|
| 166 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, truncation_side='left')
|
| 167 |
+
|
| 168 |
+
if not local_check:
|
| 169 |
+
# We assume you have enough GPUs to load one model, but you can modify the gpu_memory_dict to allow CPU offloads
|
| 170 |
+
# We recommend to use as less as GPUs possible to load one model for faster inference, such as 3 V100 32G or 2 A100 40G for 33B model
|
| 171 |
+
available_gpu_memory = get_gpu_memory(num_gpus)
|
| 172 |
+
gpu_memory_dict = {i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(num_gpus)}
|
| 173 |
+
gpu_memory_dict["cpu"] = "0GiB"
|
| 174 |
+
|
| 175 |
+
if "t5" in model_path:
|
| 176 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 177 |
+
model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 181 |
+
model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Initialize with BetterTransformer, injecting Flash-Attention
|
| 185 |
+
model = BetterTransformer.transform(model)
|
| 186 |
+
|
| 187 |
+
# turn on eval mode to stop batch normalizarion & dropout, can work together with torch.inference_mode
|
| 188 |
+
model = model.eval()
|
| 189 |
+
else:
|
| 190 |
+
model = None
|
| 191 |
+
|
| 192 |
+
output_data = []
|
| 193 |
+
for d in ques_jsons:
|
| 194 |
+
new_d = d
|
| 195 |
+
|
| 196 |
+
history = {
|
| 197 |
+
"Recent Dialogs": ["user: Hi!", "bot: Hi! How can I help you today?"],
|
| 198 |
+
"Related Topics": [],
|
| 199 |
+
"Related Summaries": [],
|
| 200 |
+
"Related Dialogs": [],
|
| 201 |
+
"User Input": "",
|
| 202 |
+
}
|
| 203 |
+
memo = {
|
| 204 |
+
"NOTO": [{"summary": "None of the others.", "dialogs": []}]
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
for l_i in range(len(new_d["conversations"])):
|
| 208 |
+
if l_i % 2 == 1:
|
| 209 |
+
bot_thinking = {"retrieval": "", "summarization": ""}
|
| 210 |
+
print("=" * 20 + "start of turn {}".format(l_i // 2 + 1) + "=" * 20)
|
| 211 |
+
user = "user: " + new_d["conversations"][l_i - 1]["value"]
|
| 212 |
+
print(user + "\n\n")
|
| 213 |
+
|
| 214 |
+
# create summary if recent dialogs exceed threshold
|
| 215 |
+
if len(" ### ".join(history["Recent Dialogs"]).split(" ")) > (MaxLen // 2) or len(history["Recent Dialogs"]) >= 10:
|
| 216 |
+
history, memo, bot_thinking = run_summary(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts)
|
| 217 |
+
|
| 218 |
+
# retrieve most related topics for every new user input
|
| 219 |
+
history["User Input"] = user
|
| 220 |
+
if len(memo.keys()) > 1:
|
| 221 |
+
history, bot_thinking = run_retrieval(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts)
|
| 222 |
+
|
| 223 |
+
# generate bot response
|
| 224 |
+
system_insturction = prompts["chatting"]["system"]
|
| 225 |
+
task_instruction = prompts["chatting"]["instruction"]
|
| 226 |
+
task_case = "```\nRelated Evidences:\n" + "\n".join(["({}) {}".format(r_tsd_i + 1, {
|
| 227 |
+
"Related Topics": history["Related Topics"][r_tsd_i],
|
| 228 |
+
"Related Summaries": history["Related Summaries"][r_tsd_i],
|
| 229 |
+
"Related Dialogs": history["Related Dialogs"][r_tsd_i]
|
| 230 |
+
}) for r_tsd_i in range(len(history["Related Topics"]))]) + "\n\nRecent Dialogs:\n" + \
|
| 231 |
+
" ### ".join([hrd.replace("\n", " ") for hrd in history["Recent Dialogs"]]) + "\n```\n\nUser Input:\n" + history["User Input"] + " ### bot: "
|
| 232 |
+
qs = q_pre + system_insturction + task_case + task_instruction + qa_link
|
| 233 |
+
outputs = gen_model_output(model_path, model, tokenizer, qs, local_check, "chatting_dialogsum")
|
| 234 |
+
outputs = normalize_chatting_outputs(outputs)
|
| 235 |
+
history["Recent Dialogs"] += [user, "bot: " + outputs]
|
| 236 |
+
print("bot: " + outputs + "\n")
|
| 237 |
+
print("=" * 20 + "end of turn {}".format(l_i // 2 + 1) + "=" * 20)
|
| 238 |
+
print("\n\n\n\n")
|
| 239 |
+
new_d["conversations"][l_i]["thinking"] = json.dumps(bot_thinking)
|
| 240 |
+
new_d["conversations"][l_i]["value"] = outputs
|
| 241 |
+
|
| 242 |
+
output_data.append(d)
|
| 243 |
+
return output_data
|
| 244 |
+
|
| 245 |
+
def run_eval(model_path, num_gpus, local_check, load_in_8bit, question_file, ray_num_gpus, answer_file, prompt_path):
|
| 246 |
+
assert num_gpus % ray_num_gpus == 0
|
| 247 |
+
prompts = json.load(open(prompt_path, "r"))
|
| 248 |
+
|
| 249 |
+
# split question file into num_gpus files
|
| 250 |
+
ques_jsons = json.load(open(os.path.expanduser(question_file), "r"))
|
| 251 |
+
|
| 252 |
+
chunk_size = len(ques_jsons) // (num_gpus // ray_num_gpus)
|
| 253 |
+
ans_handles = []
|
| 254 |
+
for i in range(0, len(ques_jsons), chunk_size):
|
| 255 |
+
get_answers_func = ray.remote(num_gpus=ray_num_gpus)(
|
| 256 |
+
get_model_answers
|
| 257 |
+
).remote
|
| 258 |
+
ans_handles.append(
|
| 259 |
+
get_answers_func(
|
| 260 |
+
model_path, ray_num_gpus, local_check, load_in_8bit, ques_jsons[i: i + chunk_size], prompts
|
| 261 |
+
)
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
ans_jsons = []
|
| 265 |
+
for ans_handle in ans_handles:
|
| 266 |
+
ans_jsons.extend(ray.get(ans_handle))
|
| 267 |
+
|
| 268 |
+
json.dump(ans_jsons, open(os.path.expanduser(answer_file), "w"), indent=2)
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
parser = argparse.ArgumentParser()
|
| 272 |
+
parser.add_argument("--model-path", type=str, required=True)
|
| 273 |
+
parser.add_argument("--num-gpus", type=int, required=True)
|
| 274 |
+
parser.add_argument("--ray-num-gpus", type=int, required=True)
|
| 275 |
+
parser.add_argument("--local-check", action="store_true", help="use faker to generate fake resposne, used for pipeline prompt checking")
|
| 276 |
+
parser.add_argument("--load-in-8bit", action="store_true")
|
| 277 |
+
parser.add_argument("--question-file", type=str, required=True)
|
| 278 |
+
parser.add_argument("--answer-file", type=str, required=True)
|
| 279 |
+
parser.add_argument("--prompt-path", type=str, required=True)
|
| 280 |
+
args = parser.parse_args()
|
| 281 |
+
|
| 282 |
+
run_eval(
|
| 283 |
+
args.model_path,
|
| 284 |
+
args.num_gpus,
|
| 285 |
+
args.local_check,
|
| 286 |
+
args.load_in_8bit,
|
| 287 |
+
args.question_file,
|
| 288 |
+
args.ray_num_gpus,
|
| 289 |
+
args.answer_file,
|
| 290 |
+
args.prompt_path
|
| 291 |
+
)
|
baselines/MemoChat/code/codes/eval/get_model_infer_simple.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
| 3 |
+
from optimum.bettertransformer import BetterTransformer
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import ray
|
| 9 |
+
|
| 10 |
+
q_pre = "<s>\n"
|
| 11 |
+
qa_link = "\n"
|
| 12 |
+
a_pos = "\n</s>"
|
| 13 |
+
MaxLen = 2048
|
| 14 |
+
TarLen = 512
|
| 15 |
+
TaskTarLen = {
|
| 16 |
+
"chatting_dialogsum": MaxLen,
|
| 17 |
+
"chatting_alpacagpt4": MaxLen,
|
| 18 |
+
"writing_topiocqa": TarLen // 2,
|
| 19 |
+
"writing_dialogsum": TarLen,
|
| 20 |
+
"retrieval_dialogsum": 32,
|
| 21 |
+
"retrieval_topiocqa": 32
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
def get_gpu_memory(ray_num_gpus):
|
| 25 |
+
"""Get available memory for each GPU."""
|
| 26 |
+
gpu_memory = []
|
| 27 |
+
for gpu_id in range(ray_num_gpus):
|
| 28 |
+
with torch.cuda.device(gpu_id):
|
| 29 |
+
device = torch.cuda.current_device()
|
| 30 |
+
gpu_properties = torch.cuda.get_device_properties(device)
|
| 31 |
+
total_memory = gpu_properties.total_memory / (1024**3)
|
| 32 |
+
allocated_memory = torch.cuda.memory_allocated() / (1024**3)
|
| 33 |
+
available_memory = total_memory - allocated_memory
|
| 34 |
+
gpu_memory.append(available_memory)
|
| 35 |
+
return gpu_memory
|
| 36 |
+
|
| 37 |
+
def run_eval(model_path, model_id, question_file, answer_file, num_gpus, load_in_8bit, ray_num_gpus):
|
| 38 |
+
assert num_gpus % ray_num_gpus == 0
|
| 39 |
+
|
| 40 |
+
# split question file into num_gpus files
|
| 41 |
+
ques_jsons = []
|
| 42 |
+
with open(os.path.expanduser(question_file), "r") as ques_file:
|
| 43 |
+
for line in ques_file:
|
| 44 |
+
ques_jsons.append(line)
|
| 45 |
+
|
| 46 |
+
chunk_size = len(ques_jsons) // (num_gpus // ray_num_gpus)
|
| 47 |
+
ans_handles = []
|
| 48 |
+
for i in range(0, len(ques_jsons), chunk_size):
|
| 49 |
+
get_answers_func = ray.remote(num_gpus=ray_num_gpus)(
|
| 50 |
+
get_model_answers
|
| 51 |
+
).remote
|
| 52 |
+
ans_handles.append(
|
| 53 |
+
get_answers_func(
|
| 54 |
+
model_path, model_id, ques_jsons[i: i + chunk_size], ray_num_gpus, load_in_8bit
|
| 55 |
+
)
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
ans_jsons = []
|
| 59 |
+
for ans_handle in ans_handles:
|
| 60 |
+
ans_jsons.extend(ray.get(ans_handle))
|
| 61 |
+
|
| 62 |
+
with open(os.path.expanduser(answer_file), "w") as ans_file:
|
| 63 |
+
for line in ans_jsons:
|
| 64 |
+
ans_file.write(json.dumps(line) + "\n")
|
| 65 |
+
|
| 66 |
+
@torch.inference_mode()
|
| 67 |
+
def get_model_answers(model_path, model_id, question_jsons, ray_num_gpus, load_in_8bit):
|
| 68 |
+
model_path = os.path.expanduser(model_path)
|
| 69 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, truncation_side='left')
|
| 70 |
+
|
| 71 |
+
available_gpu_memory = get_gpu_memory(ray_num_gpus)
|
| 72 |
+
gpu_memory_dict = {i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(ray_num_gpus)}
|
| 73 |
+
gpu_memory_dict["cpu"] = "0GiB"
|
| 74 |
+
|
| 75 |
+
if "t5" in model_path:
|
| 76 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 77 |
+
model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
|
| 78 |
+
)
|
| 79 |
+
else:
|
| 80 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 81 |
+
model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Initialize with BetterTransformer, injecting Flash-Attention
|
| 85 |
+
model = BetterTransformer.transform(model)
|
| 86 |
+
|
| 87 |
+
# turn on eval mode to stop batch normalizarion & dropout, can work together with torch.inference_mode
|
| 88 |
+
model = model.eval()
|
| 89 |
+
|
| 90 |
+
ans_jsons = []
|
| 91 |
+
for i, line in enumerate(tqdm(question_jsons)):
|
| 92 |
+
ques_json = json.loads(line)
|
| 93 |
+
idx = ques_json["question_id"]
|
| 94 |
+
qs = q_pre + ques_json["text"] + qa_link
|
| 95 |
+
|
| 96 |
+
task_type = ques_json["type"]
|
| 97 |
+
if "t5" in model_path:
|
| 98 |
+
input_ids = tokenizer([qs], max_length=MaxLen, truncation=True, add_special_tokens=False).input_ids
|
| 99 |
+
target_len = TaskTarLen[task_type]
|
| 100 |
+
else:
|
| 101 |
+
input_ids = tokenizer([qs], max_length=(MaxLen - TarLen), truncation=True, add_special_tokens=False).input_ids
|
| 102 |
+
target_len = min(len(input_ids[0]) + TaskTarLen[task_type], MaxLen)
|
| 103 |
+
|
| 104 |
+
output_ids = model.generate(
|
| 105 |
+
torch.as_tensor(input_ids).cuda(),
|
| 106 |
+
do_sample=True,
|
| 107 |
+
temperature=0.2,
|
| 108 |
+
max_length=target_len
|
| 109 |
+
)
|
| 110 |
+
if "t5" in model_path:
|
| 111 |
+
output_ids = output_ids[0]
|
| 112 |
+
else:
|
| 113 |
+
output_ids = output_ids[0][len(input_ids[0]):]
|
| 114 |
+
outputs = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
|
| 115 |
+
|
| 116 |
+
print(outputs)
|
| 117 |
+
|
| 118 |
+
ans_jsons.append(
|
| 119 |
+
{
|
| 120 |
+
"question_id": idx,
|
| 121 |
+
"text": outputs,
|
| 122 |
+
"model_id": model_id,
|
| 123 |
+
"metadata": {},
|
| 124 |
+
}
|
| 125 |
+
)
|
| 126 |
+
return ans_jsons
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
parser = argparse.ArgumentParser()
|
| 130 |
+
parser.add_argument("--model-path", type=str, required=True)
|
| 131 |
+
parser.add_argument("--model-id", type=str, required=True)
|
| 132 |
+
parser.add_argument("--question-file", type=str, required=True)
|
| 133 |
+
parser.add_argument("--answer-file", type=str, default="answer.jsonl")
|
| 134 |
+
parser.add_argument("--num-gpus", type=int, default=1)
|
| 135 |
+
parser.add_argument("--ray-num-gpus", type=int, default=1)
|
| 136 |
+
parser.add_argument("--load-in-8bit", action="store_true")
|
| 137 |
+
args = parser.parse_args()
|
| 138 |
+
|
| 139 |
+
os.environ["RAY_DEDUP_LOGS"] = "0"
|
| 140 |
+
ray.init(num_gpus=args.num_gpus)
|
| 141 |
+
|
| 142 |
+
run_eval(
|
| 143 |
+
args.model_path,
|
| 144 |
+
args.model_id,
|
| 145 |
+
args.question_file,
|
| 146 |
+
args.answer_file,
|
| 147 |
+
args.num_gpus,
|
| 148 |
+
args.load_in_8bit,
|
| 149 |
+
args.ray_num_gpus
|
| 150 |
+
)
|
baselines/MemoChat/code/codes/train/data_preprocess.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import torch
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
import datasets
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
import transformers
|
| 12 |
+
from transformers import (
|
| 13 |
+
HfArgumentParser,
|
| 14 |
+
T5Tokenizer,
|
| 15 |
+
LlamaTokenizer,
|
| 16 |
+
set_seed,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
q_pre = "<s>\n"
|
| 20 |
+
qa_link = "\n"
|
| 21 |
+
a_pos = "\n</s>"
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class ModelArguments:
|
| 27 |
+
model_name_or_path: str = field(
|
| 28 |
+
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class DataTrainingArguments:
|
| 33 |
+
data_path: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
| 34 |
+
model_max_length: int = field(
|
| 35 |
+
default=2048,
|
| 36 |
+
metadata={
|
| 37 |
+
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 38 |
+
},
|
| 39 |
+
)
|
| 40 |
+
preprocessed_path: str = field(
|
| 41 |
+
default=None, metadata={"help": "Path to the preprocessed training data."}
|
| 42 |
+
)
|
| 43 |
+
preprocessing_num_workers: Optional[int] = field(
|
| 44 |
+
default=None,
|
| 45 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def main():
|
| 49 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments))
|
| 50 |
+
model_args, data_args = parser.parse_args_into_dataclasses()
|
| 51 |
+
|
| 52 |
+
data_files = {}
|
| 53 |
+
data_files["train"] = data_args.data_path
|
| 54 |
+
raw_datasets = load_dataset(
|
| 55 |
+
"json",
|
| 56 |
+
data_files=data_files
|
| 57 |
+
)
|
| 58 |
+
column_names = raw_datasets["train"].column_names
|
| 59 |
+
print("load dataset finished")
|
| 60 |
+
|
| 61 |
+
if "t5" in model_args.model_name_or_path:
|
| 62 |
+
# use truncation_side='left' to preserve linking between end of prompt and target labels
|
| 63 |
+
tokenizer = T5Tokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
|
| 64 |
+
|
| 65 |
+
def preprocess_function(examples):
|
| 66 |
+
src_inputs = [q_pre + example[0]["value"] + qa_link for example in examples["conversations"]]
|
| 67 |
+
src_model_inputs = tokenizer(src_inputs, max_length=data_args.model_max_length, padding='longest', truncation=True, add_special_tokens=False)
|
| 68 |
+
trg_inputs = [example[1]["value"] + a_pos for example in examples["conversations"]]
|
| 69 |
+
trg_model_inputs = tokenizer(trg_inputs, max_length=data_args.model_max_length, padding='longest', truncation=True, add_special_tokens=False)
|
| 70 |
+
src_model_inputs["labels"] = [
|
| 71 |
+
[(l if l != tokenizer.pad_token_id else label_ignore_id) for l in label] for label in trg_model_inputs["input_ids"]
|
| 72 |
+
]
|
| 73 |
+
return src_model_inputs
|
| 74 |
+
else:
|
| 75 |
+
# use truncation_side='left' to preserve linking between end of prompt and target labels
|
| 76 |
+
tokenizer = LlamaTokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
|
| 77 |
+
|
| 78 |
+
def preprocess_function(examples):
|
| 79 |
+
inputs = [q_pre + example[0]["value"] + qa_link + example[1]["value"] + a_pos for example in examples["conversations"]]
|
| 80 |
+
model_inputs = tokenizer(inputs, max_length=data_args.model_max_length, padding="longest", truncation=True, add_special_tokens=False)
|
| 81 |
+
model_inputs["labels"] = copy.deepcopy(model_inputs["input_ids"])
|
| 82 |
+
for e_i, example in enumerate(examples["conversations"]):
|
| 83 |
+
source_text = q_pre + example[0]["value"] + qa_link
|
| 84 |
+
target_text = example[1]["value"] + a_pos
|
| 85 |
+
source_ids = tokenizer.encode(source_text, add_special_tokens=False)
|
| 86 |
+
target_ids = tokenizer.encode(target_text, add_special_tokens=False)
|
| 87 |
+
if len(source_ids) >= data_args.model_max_length:
|
| 88 |
+
model_inputs["labels"][e_i] = [label_ignore_id] * data_args.model_max_length
|
| 89 |
+
continue
|
| 90 |
+
else:
|
| 91 |
+
model_inputs["labels"][e_i][:len(source_ids)] = [label_ignore_id] * len(source_ids)
|
| 92 |
+
if len(target_ids) + len(source_ids) >= len(model_inputs["input_ids"][e_i]):
|
| 93 |
+
continue
|
| 94 |
+
else:
|
| 95 |
+
model_inputs["labels"][e_i][(len(target_ids) + len(source_ids)):] = [label_ignore_id] * (len(model_inputs["input_ids"][e_i]) - len(target_ids) - len(source_ids))
|
| 96 |
+
model_inputs["input_ids"] = torch.tensor(model_inputs["input_ids"])
|
| 97 |
+
model_inputs["labels"] = torch.tensor(model_inputs["labels"])
|
| 98 |
+
model_inputs["attention_mask"] = model_inputs["input_ids"].ne(tokenizer.pad_token_id)
|
| 99 |
+
return model_inputs
|
| 100 |
+
|
| 101 |
+
label_ignore_id = -100
|
| 102 |
+
|
| 103 |
+
print("start data preprocess")
|
| 104 |
+
train_dataset = raw_datasets["train"]
|
| 105 |
+
train_dataset = train_dataset.map(
|
| 106 |
+
preprocess_function,
|
| 107 |
+
batched=True,
|
| 108 |
+
batch_size=len(train_dataset),
|
| 109 |
+
remove_columns=column_names,
|
| 110 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 111 |
+
load_from_cache_file=False,
|
| 112 |
+
desc="Running tokenizer on train dataset"
|
| 113 |
+
)
|
| 114 |
+
train_dataset.save_to_disk(data_args.preprocessed_path)
|
| 115 |
+
print("data preprocess finished")
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
main()
|
baselines/MemoChat/code/codes/train/train.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import torch
|
| 6 |
+
import copy
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
import datasets
|
| 10 |
+
from datasets import load_from_disk
|
| 11 |
+
|
| 12 |
+
import transformers
|
| 13 |
+
from transformers import (
|
| 14 |
+
HfArgumentParser,
|
| 15 |
+
T5ForConditionalGeneration,
|
| 16 |
+
T5Tokenizer,
|
| 17 |
+
T5Config,
|
| 18 |
+
LlamaForCausalLM,
|
| 19 |
+
LlamaTokenizer,
|
| 20 |
+
LlamaConfig,
|
| 21 |
+
Trainer,
|
| 22 |
+
TrainingArguments,
|
| 23 |
+
set_seed,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from optimum.bettertransformer import BetterTransformer
|
| 27 |
+
|
| 28 |
+
q_pre = "<s>\n"
|
| 29 |
+
qa_link = "\n"
|
| 30 |
+
a_pos = "\n</s>"
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class ModelArguments:
|
| 36 |
+
model_name_or_path: str = field(
|
| 37 |
+
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class DataTrainingArguments:
|
| 42 |
+
model_max_length: int = field(
|
| 43 |
+
default=2048,
|
| 44 |
+
metadata={
|
| 45 |
+
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 46 |
+
},
|
| 47 |
+
)
|
| 48 |
+
max_train_samples: Optional[int] = field(
|
| 49 |
+
default=None,
|
| 50 |
+
metadata={
|
| 51 |
+
"help": (
|
| 52 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 53 |
+
"value if set."
|
| 54 |
+
)
|
| 55 |
+
},
|
| 56 |
+
)
|
| 57 |
+
preprocessed_path: str = field(
|
| 58 |
+
default=None, metadata={"help": "Path to the preprocessed training data."}
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def main():
|
| 62 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
| 63 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 64 |
+
|
| 65 |
+
# Setup logging
|
| 66 |
+
logging.basicConfig(
|
| 67 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 68 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 69 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if training_args.should_log:
|
| 73 |
+
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
| 74 |
+
transformers.utils.logging.set_verbosity_info()
|
| 75 |
+
|
| 76 |
+
log_level = training_args.get_process_log_level()
|
| 77 |
+
logger.setLevel(log_level)
|
| 78 |
+
datasets.utils.logging.set_verbosity(log_level)
|
| 79 |
+
transformers.utils.logging.set_verbosity(log_level)
|
| 80 |
+
transformers.utils.logging.enable_default_handler()
|
| 81 |
+
transformers.utils.logging.enable_explicit_format()
|
| 82 |
+
# Log on each process the small summary:
|
| 83 |
+
logger.warning(
|
| 84 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
| 85 |
+
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
|
| 86 |
+
)
|
| 87 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
| 88 |
+
|
| 89 |
+
logger.info("start to load dataset")
|
| 90 |
+
train_dataset = load_from_disk(data_args.preprocessed_path)
|
| 91 |
+
column_names = train_dataset.column_names
|
| 92 |
+
if data_args.max_train_samples is not None:
|
| 93 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
| 94 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
| 95 |
+
logger.info("load dataset finished")
|
| 96 |
+
|
| 97 |
+
if "t5" in model_args.model_name_or_path:
|
| 98 |
+
# load config and tokenziers
|
| 99 |
+
config = T5Config.from_pretrained(model_args.model_name_or_path)
|
| 100 |
+
config.use_cache=False
|
| 101 |
+
# use truncation_side='left' to preserve linking between end of prompt and target labels
|
| 102 |
+
tokenizer = T5Tokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
|
| 103 |
+
model = T5ForConditionalGeneration.from_pretrained(model_args.model_name_or_path, config=config)
|
| 104 |
+
else:
|
| 105 |
+
# load config and tokenziers
|
| 106 |
+
config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
|
| 107 |
+
config.use_cache=False
|
| 108 |
+
# use truncation_side='left' to preserve linking between end of prompt and target labels
|
| 109 |
+
tokenizer = LlamaTokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
|
| 110 |
+
# initialize modules
|
| 111 |
+
model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
|
| 112 |
+
|
| 113 |
+
# convert normal model to bettertransformer
|
| 114 |
+
model = BetterTransformer.transform(model)
|
| 115 |
+
|
| 116 |
+
# Setup seed
|
| 117 |
+
set_seed(training_args.seed)
|
| 118 |
+
if len(tokenizer) > tokenizer.vocab_size:
|
| 119 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 120 |
+
|
| 121 |
+
# Setup Trainer
|
| 122 |
+
trainer = Trainer(
|
| 123 |
+
model=model,
|
| 124 |
+
args=training_args,
|
| 125 |
+
train_dataset=train_dataset,
|
| 126 |
+
tokenizer=tokenizer
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Training
|
| 130 |
+
train_result = trainer.train()
|
| 131 |
+
|
| 132 |
+
# convert bettertransformer to normal model
|
| 133 |
+
trainer.model = BetterTransformer.reverse(trainer.model)
|
| 134 |
+
trainer.save_state()
|
| 135 |
+
|
| 136 |
+
# save fp16 model under deepspeed zero2 or zero3
|
| 137 |
+
c_stage = json.load(open(training_args.deepspeed, "r"))["zero_optimization"]["stage"]
|
| 138 |
+
if c_stage in [2, 3]:
|
| 139 |
+
if c_stage == 2:
|
| 140 |
+
w_state_dict = trainer.model.state_dict()
|
| 141 |
+
else:
|
| 142 |
+
w_state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
|
| 143 |
+
if trainer.is_world_process_zero():
|
| 144 |
+
state_dict = {key: value.half().cpu() for key, value in w_state_dict.items()}
|
| 145 |
+
trainer._save(training_args.output_dir, state_dict=state_dict)
|
| 146 |
+
else:
|
| 147 |
+
trainer.save_model()
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
main()
|
baselines/MemoChat/code/configs/ds_config_13b.json
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"scheduler": {
|
| 23 |
+
"type": "WarmupDecayLR",
|
| 24 |
+
"params": {
|
| 25 |
+
"total_num_steps" : "auto",
|
| 26 |
+
"warmup_min_lr": "auto",
|
| 27 |
+
"warmup_max_lr": "auto",
|
| 28 |
+
"warmup_num_steps": "auto"
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"zero_optimization": {
|
| 32 |
+
"stage": 3,
|
| 33 |
+
"offload_optimizer": {
|
| 34 |
+
"device": "cpu",
|
| 35 |
+
"pin_memory": true
|
| 36 |
+
},
|
| 37 |
+
"overlap_comm": true,
|
| 38 |
+
"contiguous_gradients": true,
|
| 39 |
+
"sub_group_size": 1e9,
|
| 40 |
+
"reduce_bucket_size": "auto",
|
| 41 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 42 |
+
"stage3_param_persistence_threshold": "auto",
|
| 43 |
+
"stage3_max_live_parameters": 1e9,
|
| 44 |
+
"stage3_max_reuse_distance": 1e9,
|
| 45 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 46 |
+
},
|
| 47 |
+
"gradient_accumulation_steps": "auto",
|
| 48 |
+
"gradient_clipping": "auto",
|
| 49 |
+
"steps_per_print": 2000,
|
| 50 |
+
"train_batch_size": "auto",
|
| 51 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 52 |
+
"wall_clock_breakdown": false
|
| 53 |
+
}
|
baselines/MemoChat/code/configs/ds_config_33b.json
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"scheduler": {
|
| 23 |
+
"type": "WarmupDecayLR",
|
| 24 |
+
"params": {
|
| 25 |
+
"total_num_steps" : "auto",
|
| 26 |
+
"warmup_min_lr": "auto",
|
| 27 |
+
"warmup_max_lr": "auto",
|
| 28 |
+
"warmup_num_steps": "auto"
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"zero_optimization": {
|
| 32 |
+
"stage": 3,
|
| 33 |
+
"offload_optimizer": {
|
| 34 |
+
"device": "cpu",
|
| 35 |
+
"pin_memory": true
|
| 36 |
+
},
|
| 37 |
+
"offload_param": {
|
| 38 |
+
"device": "cpu",
|
| 39 |
+
"pin_memory": true
|
| 40 |
+
},
|
| 41 |
+
"overlap_comm": true,
|
| 42 |
+
"contiguous_gradients": true,
|
| 43 |
+
"sub_group_size": 1e9,
|
| 44 |
+
"reduce_bucket_size": "auto",
|
| 45 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 46 |
+
"stage3_param_persistence_threshold": "auto",
|
| 47 |
+
"stage3_max_live_parameters": 1e9,
|
| 48 |
+
"stage3_max_reuse_distance": 1e9,
|
| 49 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 50 |
+
},
|
| 51 |
+
"gradient_accumulation_steps": "auto",
|
| 52 |
+
"gradient_clipping": "auto",
|
| 53 |
+
"steps_per_print": 2000,
|
| 54 |
+
"train_batch_size": "auto",
|
| 55 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 56 |
+
"wall_clock_breakdown": false
|
| 57 |
+
}
|
baselines/MemoChat/code/configs/ds_config_3b.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"scheduler": {
|
| 23 |
+
"type": "WarmupLR",
|
| 24 |
+
"params": {
|
| 25 |
+
"warmup_min_lr": "auto",
|
| 26 |
+
"warmup_max_lr": "auto",
|
| 27 |
+
"warmup_num_steps": "auto"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"zero_optimization": {
|
| 31 |
+
"stage": 1
|
| 32 |
+
},
|
| 33 |
+
"gradient_accumulation_steps": "auto",
|
| 34 |
+
"gradient_clipping": "auto",
|
| 35 |
+
"steps_per_print": 2000,
|
| 36 |
+
"train_batch_size": "auto",
|
| 37 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 38 |
+
"wall_clock_breakdown": false
|
| 39 |
+
}
|
baselines/MemoChat/code/configs/ds_config_7b.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"scheduler": {
|
| 23 |
+
"type": "WarmupLR",
|
| 24 |
+
"params": {
|
| 25 |
+
"warmup_min_lr": "auto",
|
| 26 |
+
"warmup_max_lr": "auto",
|
| 27 |
+
"warmup_num_steps": "auto"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"zero_optimization": {
|
| 31 |
+
"stage": 2,
|
| 32 |
+
"offload_optimizer": {
|
| 33 |
+
"device": "cpu",
|
| 34 |
+
"pin_memory": true
|
| 35 |
+
},
|
| 36 |
+
"allgather_partitions": true,
|
| 37 |
+
"allgather_bucket_size": 2e8,
|
| 38 |
+
"overlap_comm": true,
|
| 39 |
+
"reduce_scatter": true,
|
| 40 |
+
"reduce_bucket_size": 2e8,
|
| 41 |
+
"contiguous_gradients": true
|
| 42 |
+
},
|
| 43 |
+
"gradient_accumulation_steps": "auto",
|
| 44 |
+
"gradient_clipping": "auto",
|
| 45 |
+
"steps_per_print": 2000,
|
| 46 |
+
"train_batch_size": "auto",
|
| 47 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 48 |
+
"wall_clock_breakdown": false
|
| 49 |
+
}
|
baselines/MemoChat/code/scripts/llm_judge.sh
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export GLOO_SOCKET_IFNAME=eth0
|
| 2 |
+
export WANDB_MODE=disabled
|
| 3 |
+
|
| 4 |
+
maindir=$1
|
| 5 |
+
datadir=${maindir}data
|
| 6 |
+
codedir=${maindir}code
|
| 7 |
+
|
| 8 |
+
settings=("10k")
|
| 9 |
+
models=("t5-3b" "vicuna-7b" "vicuna-13b" "vicuna-33b")
|
| 10 |
+
|
| 11 |
+
for model in "${models[@]}"
|
| 12 |
+
do
|
| 13 |
+
for setting in "${settings[@]}"
|
| 14 |
+
do
|
| 15 |
+
python3 ${codedir}/codes/api/llm_judge.py \
|
| 16 |
+
${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_${model}_${setting}.json \
|
| 17 |
+
gpt-4 \
|
| 18 |
+
YourOpenAIKey \
|
| 19 |
+
${datadir}/llm_judge/llm_judge_gpt-4_${model}_${setting}.json \
|
| 20 |
+
${datadir}/prompts.json
|
| 21 |
+
done
|
| 22 |
+
done
|
| 23 |
+
|
| 24 |
+
gpt_settings=("2k" "memochat")
|
| 25 |
+
|
| 26 |
+
for gpt_setting in "${gpt_settings[@]}"
|
| 27 |
+
do
|
| 28 |
+
python3 ${codedir}/codes/api/llm_judge.py \
|
| 29 |
+
${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_gpt-3.5-turbo-${gpt_setting}.json \
|
| 30 |
+
gpt-4 \
|
| 31 |
+
YourOpenAIKey \
|
| 32 |
+
${datadir}/llm_judge/llm_judge_gpt-4_gpt-3.5-turbo-${gpt_setting}.json \
|
| 33 |
+
${datadir}/prompts.json
|
| 34 |
+
done
|
| 35 |
+
|
baselines/MemoChat/code/scripts/memochat.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export GLOO_SOCKET_IFNAME=eth0
|
| 2 |
+
export WANDB_MODE=disabled
|
| 3 |
+
|
| 4 |
+
maindir=$1
|
| 5 |
+
datadir=${maindir}data
|
| 6 |
+
codedir=${maindir}code
|
| 7 |
+
|
| 8 |
+
test_data=${datadir}/mtbenchplus/mtbenchplus.json
|
| 9 |
+
|
| 10 |
+
settings=("1k", "10k")
|
| 11 |
+
models=("t5-3b" "vicuna-7b" "vicuna-13b" "vicuna-33b")
|
| 12 |
+
|
| 13 |
+
for model in "${models[@]}"
|
| 14 |
+
do
|
| 15 |
+
for setting in "${settings[@]}"
|
| 16 |
+
do
|
| 17 |
+
finetuned_model_path=${maindir}model/${model}_${setting}/
|
| 18 |
+
case ${model} in
|
| 19 |
+
"vicuna-33b")
|
| 20 |
+
RAYGPUS=2
|
| 21 |
+
;;
|
| 22 |
+
"t5-3b"|"vicuna-7b"|"vicuna-13b")
|
| 23 |
+
RAYGPUS=1
|
| 24 |
+
;;
|
| 25 |
+
esac
|
| 26 |
+
python3 ${codedir}/codes/eval/get_model_infer_memochat.py \
|
| 27 |
+
--model-path ${finetuned_model_path} \
|
| 28 |
+
--question-file ${test_data} \
|
| 29 |
+
--answer-file ${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_${model}_${setting}.json \
|
| 30 |
+
--num-gpus $GPU_NUM_PER_NODE \
|
| 31 |
+
--ray-num-gpus ${RAYGPUS} \
|
| 32 |
+
--prompt-path ${datadir}/prompts.json
|
| 33 |
+
done
|
| 34 |
+
done
|
baselines/MemoChat/code/scripts/memochat_gpt.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export GLOO_SOCKET_IFNAME=eth0
|
| 2 |
+
export WANDB_MODE=disabled
|
| 3 |
+
|
| 4 |
+
maindir=$1
|
| 5 |
+
datadir=${maindir}data
|
| 6 |
+
codedir=${maindir}code
|
| 7 |
+
|
| 8 |
+
gpt_settings=("2k" "memochat")
|
| 9 |
+
|
| 10 |
+
for gpt_setting in "${gpt_settings[@]}"
|
| 11 |
+
do
|
| 12 |
+
python3 ${codedir}/codes/api/gpt_${gpt_setting}.py \
|
| 13 |
+
${datadir}/mtbenchplus/mtbenchplus.json \
|
| 14 |
+
gpt-3.5-turbo \
|
| 15 |
+
YourOpenAIKey \
|
| 16 |
+
${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_gpt-3.5-turbo-${gpt_setting}.json \
|
| 17 |
+
${datadir}/prompts.json
|
| 18 |
+
done
|
baselines/MemoChat/code/scripts/tuning.sh
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export GLOO_SOCKET_IFNAME=eth0
|
| 2 |
+
export WANDB_MODE=disabled
|
| 3 |
+
|
| 4 |
+
maindir=$1
|
| 5 |
+
datadir=${maindir}data
|
| 6 |
+
codedir=${maindir}code
|
| 7 |
+
|
| 8 |
+
MAXLEN=2048
|
| 9 |
+
EPOCH=3
|
| 10 |
+
test_data=${datadir}/memochat_instructions/test.jsonl
|
| 11 |
+
|
| 12 |
+
settings=("1k" "10k")
|
| 13 |
+
models=("t5-3b" "vicuna-7b" "vicuna-13b" "vicuna-33b")
|
| 14 |
+
|
| 15 |
+
for model in "${models[@]}"
|
| 16 |
+
do
|
| 17 |
+
|
| 18 |
+
raw_model_path=${maindir}model/fastchat-${model}/
|
| 19 |
+
case ${model} in
|
| 20 |
+
"vicuna-33b")
|
| 21 |
+
RAYGPUS=2
|
| 22 |
+
;;
|
| 23 |
+
"t5-3b"|"vicuna-7b"|"vicuna-13b")
|
| 24 |
+
RAYGPUS=1
|
| 25 |
+
;;
|
| 26 |
+
esac
|
| 27 |
+
|
| 28 |
+
# zeroshot inference on one node
|
| 29 |
+
python3 ${codedir}/codes/eval/get_model_infer_simple.py \
|
| 30 |
+
--model-id ${model}_zeroshot \
|
| 31 |
+
--model-path ${raw_model_path} \
|
| 32 |
+
--question-file ${test_data} \
|
| 33 |
+
--answer-file ${datadir}/instruction_testing/instruction_testing_${model}_zeroshot.jsonl \
|
| 34 |
+
--num-gpus $GPU_NUM_PER_NODE \
|
| 35 |
+
--ray-num-gpus ${RAYGPUS}
|
| 36 |
+
|
| 37 |
+
# tuning
|
| 38 |
+
for setting in "${settings[@]}"
|
| 39 |
+
do
|
| 40 |
+
data_path=${datadir}/memochat_instructions/train_${setting}.json
|
| 41 |
+
preprocessed_data_dir=${datadir}/memochat_instructions/processed_${setting}_${model%-*}.pt
|
| 42 |
+
model_output_path=${maindir}model/${model}_${setting}/
|
| 43 |
+
deepspeed_config_path=${codedir}/configs/ds_config_${model#*-}.json
|
| 44 |
+
|
| 45 |
+
case ${model} in
|
| 46 |
+
"t5-3b")
|
| 47 |
+
PER_GPU_BATCH=8
|
| 48 |
+
GRA_ACC=2
|
| 49 |
+
;;
|
| 50 |
+
"vicuna-7b")
|
| 51 |
+
PER_GPU_BATCH=16
|
| 52 |
+
GRA_ACC=1
|
| 53 |
+
;;
|
| 54 |
+
"vicuna-13b")
|
| 55 |
+
PER_GPU_BATCH=8
|
| 56 |
+
GRA_ACC=2
|
| 57 |
+
;;
|
| 58 |
+
"vicuna-33b")
|
| 59 |
+
PER_GPU_BATCH=4
|
| 60 |
+
GRA_ACC=4
|
| 61 |
+
;;
|
| 62 |
+
esac
|
| 63 |
+
|
| 64 |
+
# train data preprocess
|
| 65 |
+
python3 ${codedir}/codes/train/data_preprocess.py \
|
| 66 |
+
--model_name_or_path ${raw_model_path} \
|
| 67 |
+
--data_path ${data_path} \
|
| 68 |
+
--preprocessing_num_workers=1 \
|
| 69 |
+
--model_max_length ${MAXLEN} \
|
| 70 |
+
--preprocessed_path ${preprocessed_data_dir}
|
| 71 |
+
|
| 72 |
+
# training: avaliable for multi nodes
|
| 73 |
+
torchrun --nnodes=$NODE_NUM \
|
| 74 |
+
--node_rank=$INDEX \
|
| 75 |
+
--nproc_per_node $GPU_NUM_PER_NODE \
|
| 76 |
+
--master_addr $MASTER_ADDR \
|
| 77 |
+
--master_port $MASTER_PORT \
|
| 78 |
+
${codedir}/codes/train/train.py \
|
| 79 |
+
--model_name_or_path ${raw_model_path} \
|
| 80 |
+
--bf16 True \
|
| 81 |
+
--output_dir ${model_output_path} \
|
| 82 |
+
--num_train_epochs ${EPOCH} \
|
| 83 |
+
--per_device_train_batch_size ${PER_GPU_BATCH} \
|
| 84 |
+
--gradient_accumulation_steps ${GRA_ACC} \
|
| 85 |
+
--save_strategy "steps" \
|
| 86 |
+
--save_steps 1500 \
|
| 87 |
+
--save_total_limit 1 \
|
| 88 |
+
--learning_rate 2e-5 \
|
| 89 |
+
--log_level "info" \
|
| 90 |
+
--logging_strategy "steps" \
|
| 91 |
+
--logging_steps 1 \
|
| 92 |
+
--weight_decay 0. \
|
| 93 |
+
--warmup_ratio 0.04 \
|
| 94 |
+
--lr_scheduler_type "cosine" \
|
| 95 |
+
--deepspeed ${deepspeed_config_path} \
|
| 96 |
+
--tf32 True \
|
| 97 |
+
--model_max_length ${MAXLEN} \
|
| 98 |
+
--preprocessed_path ${preprocessed_data_dir} \
|
| 99 |
+
--gradient_checkpointing True
|
| 100 |
+
|
| 101 |
+
# tuning inference
|
| 102 |
+
python3 ${codedir}/codes/eval/get_model_infer_simple.py \
|
| 103 |
+
--model-id ${model}_${setting} \
|
| 104 |
+
--model-path ${model_output_path} \
|
| 105 |
+
--question-file ${test_data} \
|
| 106 |
+
--answer-file ${datadir}/instruction_testing/instruction_testing_${model}_${setting}.jsonl \
|
| 107 |
+
--num-gpus $GPU_NUM_PER_NODE \
|
| 108 |
+
--ray-num-gpus ${RAYGPUS}
|
| 109 |
+
done
|
| 110 |
+
done
|
baselines/MemoChat/core_requirement.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==0.19.0
|
| 2 |
+
datasets==2.10.1
|
| 3 |
+
deepspeed==0.9.4
|
| 4 |
+
evaluate==0.4.0
|
| 5 |
+
Faker==18.11.2
|
| 6 |
+
openai==0.27.2
|
| 7 |
+
optimum==1.9.1
|
| 8 |
+
ray==2.5.1
|
| 9 |
+
tiktoken==0.4.0
|
| 10 |
+
tokenizers==0.13.2
|
| 11 |
+
torch==2.0.1
|
| 12 |
+
torchtext==0.15.2
|
| 13 |
+
transformers==4.29.2
|
baselines/MemoChat/run_memochat_baseline.py
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MemoChat baseline for the EvolV-Mem benchmark.
|
| 3 |
+
|
| 4 |
+
Adapts MemoChat's three-stage pipeline (memo writing, retrieval, chatting)
|
| 5 |
+
to the EvolV-Mem benchmark using Qwen-30B via vLLM.
|
| 6 |
+
|
| 7 |
+
Pipeline per question:
|
| 8 |
+
1. Memo writing: extract {topic, summary} from each haystack session (cached)
|
| 9 |
+
2. Embedding pre-filter: SBert selects top-50 memos by similarity
|
| 10 |
+
3. MemoChat retrieval: LLM selects final relevant topics from top-50
|
| 11 |
+
4. Answer generation: LLM generates answer from retrieved memos
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python baselines/MemoChat/run_memochat_baseline.py \
|
| 15 |
+
--in_file dataset/evolv_mem_v4.json \
|
| 16 |
+
--out_file output/memochat_qwen30b_v4.jsonl \
|
| 17 |
+
--sessions_file dataset/all_sessions.json \
|
| 18 |
+
--profile_file metadata/generated_user_profile.json
|
| 19 |
+
|
| 20 |
+
Env vars:
|
| 21 |
+
VLLM_BASE_URL (default http://localhost:8000/v1)
|
| 22 |
+
VLLM_API_KEY (default EMPTY)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import logging
|
| 28 |
+
import os
|
| 29 |
+
import re
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
from collections import defaultdict
|
| 33 |
+
from typing import Dict, List, Optional, Tuple
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
|
| 38 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Load MemoChat prompts
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 44 |
+
PROMPTS_PATH = os.path.join(SCRIPT_DIR, "data", "prompts.json")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# vLLM LLM helper
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def get_llm_client():
|
| 52 |
+
from openai import OpenAI
|
| 53 |
+
return OpenAI(
|
| 54 |
+
base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
|
| 55 |
+
api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
MODEL_NAME = os.getenv("VLLM_MODEL_NAME", "Qwen/Qwen3-30B-A3B-Instruct-2507")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def llm_call(client, prompt: str, max_tokens: int = 4096, temperature: float = 0.2) -> str:
|
| 63 |
+
"""Call the vLLM server with retry logic."""
|
| 64 |
+
for attempt in range(6):
|
| 65 |
+
try:
|
| 66 |
+
response = client.chat.completions.create(
|
| 67 |
+
model=MODEL_NAME,
|
| 68 |
+
messages=[{"role": "user", "content": prompt}],
|
| 69 |
+
max_tokens=max_tokens,
|
| 70 |
+
temperature=temperature,
|
| 71 |
+
)
|
| 72 |
+
content = response.choices[0].message.content if response.choices else None
|
| 73 |
+
if content is None:
|
| 74 |
+
wait = min(2 ** attempt * 2, 30)
|
| 75 |
+
print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
|
| 76 |
+
time.sleep(wait)
|
| 77 |
+
continue
|
| 78 |
+
return content.strip()
|
| 79 |
+
except Exception as e:
|
| 80 |
+
msg = str(e).lower()
|
| 81 |
+
if any(code in msg for code in ("429", "500", "503", "rate limit")):
|
| 82 |
+
wait = min(2 ** attempt * 5, 60)
|
| 83 |
+
print(f"[WARN] LLM retry {attempt+1}/6, sleeping {wait}s: {e}")
|
| 84 |
+
time.sleep(wait)
|
| 85 |
+
continue
|
| 86 |
+
print(f"[ERROR] LLM call failed: {e}")
|
| 87 |
+
raise
|
| 88 |
+
raise RuntimeError("LLM call failed after 6 retries")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
# MemoChat output parsers (ported from get_model_infer_memochat.py)
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
|
| 95 |
+
def normalize_model_outputs(model_text: str) -> List[Dict]:
|
| 96 |
+
"""Parse memo writing output into structured topic-summary dicts."""
|
| 97 |
+
extracted_elements = [
|
| 98 |
+
re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", ""))
|
| 99 |
+
for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)
|
| 100 |
+
]
|
| 101 |
+
model_outputs = []
|
| 102 |
+
ti = 0
|
| 103 |
+
while ti + 7 < len(extracted_elements):
|
| 104 |
+
if (extracted_elements[ti] == "topic"
|
| 105 |
+
and extracted_elements[ti + 2] == "summary"
|
| 106 |
+
and extracted_elements[ti + 4] == "start"
|
| 107 |
+
and extracted_elements[ti + 6] == "end"):
|
| 108 |
+
try:
|
| 109 |
+
model_outputs.append({
|
| 110 |
+
"topic": extracted_elements[ti + 1],
|
| 111 |
+
"summary": extracted_elements[ti + 3],
|
| 112 |
+
"start": int(extracted_elements[ti + 5]),
|
| 113 |
+
"end": int(extracted_elements[ti + 7]),
|
| 114 |
+
})
|
| 115 |
+
except (ValueError, IndexError):
|
| 116 |
+
pass
|
| 117 |
+
ti += 1
|
| 118 |
+
return model_outputs
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def normalize_chatting_outputs(model_outputs: str) -> str:
|
| 122 |
+
"""Clean up chatting response whitespace."""
|
| 123 |
+
lines = model_outputs.split("\n")
|
| 124 |
+
result = [' '.join(line.split()) for line in lines]
|
| 125 |
+
return '\n'.join(result)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
# Stage 1: Memo Writing (per session, cached)
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
|
| 132 |
+
def format_session_for_writing(session_turns: List[Dict]) -> str:
|
| 133 |
+
"""Format a session's turns as numbered lines for MemoChat's writing prompt."""
|
| 134 |
+
lines = []
|
| 135 |
+
for i, turn in enumerate(session_turns):
|
| 136 |
+
role = turn.get("role", "user")
|
| 137 |
+
content = turn.get("content", "").replace("\n", " ")
|
| 138 |
+
lines.append(f"(line {i + 1}) {role}: {content}")
|
| 139 |
+
return "\n".join(lines)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def write_memos_for_session(
|
| 143 |
+
session_id: str,
|
| 144 |
+
session_turns: List[Dict],
|
| 145 |
+
client,
|
| 146 |
+
prompts: Dict,
|
| 147 |
+
memo_cache_dir: str,
|
| 148 |
+
) -> List[Dict]:
|
| 149 |
+
"""Extract topic-summary memos from a single session using MemoChat's writing prompt.
|
| 150 |
+
|
| 151 |
+
Returns list of {topic, summary, start, end, session_id}.
|
| 152 |
+
Results are cached to disk.
|
| 153 |
+
"""
|
| 154 |
+
cache_path = os.path.join(memo_cache_dir, f"{session_id}.json")
|
| 155 |
+
if os.path.exists(cache_path):
|
| 156 |
+
with open(cache_path) as f:
|
| 157 |
+
return json.load(f)
|
| 158 |
+
|
| 159 |
+
if not session_turns:
|
| 160 |
+
memos = []
|
| 161 |
+
with open(cache_path, "w") as f:
|
| 162 |
+
json.dump(memos, f)
|
| 163 |
+
return memos
|
| 164 |
+
|
| 165 |
+
# Build MemoChat writing prompt
|
| 166 |
+
num_lines = len(session_turns)
|
| 167 |
+
system_instruction = prompts["writing_dialogsum"]["system"]
|
| 168 |
+
task_instruction = prompts["writing_dialogsum"]["instruction"]
|
| 169 |
+
|
| 170 |
+
history_log = "\n\n```\nTask Conversation:\n" + format_session_for_writing(session_turns)
|
| 171 |
+
prompt = (
|
| 172 |
+
system_instruction.replace("LINE", str(num_lines))
|
| 173 |
+
+ history_log
|
| 174 |
+
+ "\n```"
|
| 175 |
+
+ task_instruction.replace("LINE", str(num_lines))
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Call LLM
|
| 179 |
+
output = llm_call(client, prompt, max_tokens=512, temperature=0.2)
|
| 180 |
+
memos_raw = normalize_model_outputs(output)
|
| 181 |
+
|
| 182 |
+
# Attach session_id to each memo
|
| 183 |
+
memos = []
|
| 184 |
+
for m in memos_raw:
|
| 185 |
+
memos.append({
|
| 186 |
+
"topic": m["topic"],
|
| 187 |
+
"summary": m["summary"],
|
| 188 |
+
"start": m["start"],
|
| 189 |
+
"end": m["end"],
|
| 190 |
+
"session_id": session_id,
|
| 191 |
+
})
|
| 192 |
+
|
| 193 |
+
# If no memos extracted, create a fallback from the session content
|
| 194 |
+
if not memos:
|
| 195 |
+
# Use first and last turn as a basic summary
|
| 196 |
+
first_content = session_turns[0].get("content", "")[:200] if session_turns else ""
|
| 197 |
+
memos.append({
|
| 198 |
+
"topic": f"session_{session_id}",
|
| 199 |
+
"summary": f"Conversation about: {first_content}...",
|
| 200 |
+
"start": 1,
|
| 201 |
+
"end": num_lines,
|
| 202 |
+
"session_id": session_id,
|
| 203 |
+
})
|
| 204 |
+
|
| 205 |
+
# Cache
|
| 206 |
+
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
|
| 207 |
+
with open(cache_path, "w") as f:
|
| 208 |
+
json.dump(memos, f)
|
| 209 |
+
|
| 210 |
+
return memos
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# ---------------------------------------------------------------------------
|
| 214 |
+
# Stage 1 (fast): Build memos from pre-computed session summaries
|
| 215 |
+
# ---------------------------------------------------------------------------
|
| 216 |
+
|
| 217 |
+
def build_memos_from_summaries(
|
| 218 |
+
haystack_session_ids: List[str],
|
| 219 |
+
haystack_dates: List[str],
|
| 220 |
+
summaries: Dict,
|
| 221 |
+
) -> List[Dict]:
|
| 222 |
+
"""Build memo entries directly from all_session_summary.json — no LLM calls."""
|
| 223 |
+
memos = []
|
| 224 |
+
for sid, date_str in zip(haystack_session_ids, haystack_dates):
|
| 225 |
+
summary_data = summaries.get(sid)
|
| 226 |
+
if summary_data is None:
|
| 227 |
+
continue
|
| 228 |
+
text = summary_data.get("session_summary", "")
|
| 229 |
+
if not text:
|
| 230 |
+
turn_sums = summary_data.get("turn_summaries", [])
|
| 231 |
+
if turn_sums:
|
| 232 |
+
text = " ".join(turn_sums)
|
| 233 |
+
else:
|
| 234 |
+
continue
|
| 235 |
+
# Use first ~60 chars as topic, full text as summary
|
| 236 |
+
topic = text[:60].rstrip(". ") if len(text) > 60 else text
|
| 237 |
+
memos.append({
|
| 238 |
+
"topic": topic,
|
| 239 |
+
"summary": text,
|
| 240 |
+
"session_id": sid,
|
| 241 |
+
"session_date": date_str,
|
| 242 |
+
})
|
| 243 |
+
return memos
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ---------------------------------------------------------------------------
|
| 247 |
+
# Stage 2: Embedding pre-filter
|
| 248 |
+
# ---------------------------------------------------------------------------
|
| 249 |
+
|
| 250 |
+
def embed_and_filter(
|
| 251 |
+
question: str,
|
| 252 |
+
all_memos: List[Dict],
|
| 253 |
+
embedding_model,
|
| 254 |
+
top_k: int = 50,
|
| 255 |
+
) -> List[Dict]:
|
| 256 |
+
"""Use SBert to select top-k most relevant memos by cosine similarity."""
|
| 257 |
+
if len(all_memos) <= top_k:
|
| 258 |
+
return all_memos
|
| 259 |
+
|
| 260 |
+
# Build texts to embed
|
| 261 |
+
memo_texts = [f"{m['topic']}. {m['summary']}" for m in all_memos]
|
| 262 |
+
|
| 263 |
+
# Encode
|
| 264 |
+
question_emb = embedding_model.encode(question)
|
| 265 |
+
memo_embs = embedding_model.encode(memo_texts)
|
| 266 |
+
|
| 267 |
+
# Cosine similarity
|
| 268 |
+
question_norm = question_emb / (np.linalg.norm(question_emb) + 1e-10)
|
| 269 |
+
memo_norms = memo_embs / (np.linalg.norm(memo_embs, axis=1, keepdims=True) + 1e-10)
|
| 270 |
+
similarities = memo_norms @ question_norm
|
| 271 |
+
|
| 272 |
+
# Top-k indices
|
| 273 |
+
top_indices = np.argsort(similarities)[::-1][:top_k]
|
| 274 |
+
return [all_memos[i] for i in top_indices]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# ---------------------------------------------------------------------------
|
| 278 |
+
# Stage 3: MemoChat LLM-based retrieval
|
| 279 |
+
# ---------------------------------------------------------------------------
|
| 280 |
+
|
| 281 |
+
def memochat_retrieve(
|
| 282 |
+
question: str,
|
| 283 |
+
candidate_memos: List[Dict],
|
| 284 |
+
client,
|
| 285 |
+
prompts: Dict,
|
| 286 |
+
) -> List[Dict]:
|
| 287 |
+
"""Apply MemoChat's retrieval prompt to select relevant memos from candidates."""
|
| 288 |
+
if not candidate_memos:
|
| 289 |
+
return []
|
| 290 |
+
|
| 291 |
+
# Build topic options list
|
| 292 |
+
topic_options = []
|
| 293 |
+
for i, m in enumerate(candidate_memos):
|
| 294 |
+
topic_options.append(f"({i + 1}) {m['topic']}. {m['summary']}")
|
| 295 |
+
|
| 296 |
+
# Add NOTO option
|
| 297 |
+
noto_idx = len(candidate_memos) + 1
|
| 298 |
+
topic_options.append(f"({noto_idx}) NOTO. None of the others.")
|
| 299 |
+
|
| 300 |
+
system_instruction = prompts["retrieval"]["system"]
|
| 301 |
+
task_instruction = prompts["retrieval"]["instruction"]
|
| 302 |
+
|
| 303 |
+
task_case = (
|
| 304 |
+
"```\nQuery Sentence:\n" + question
|
| 305 |
+
+ "\nTopic Options:\n" + "\n".join(topic_options)
|
| 306 |
+
+ "\n```"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
prompt = (
|
| 310 |
+
system_instruction.replace("OPTION", str(len(candidate_memos) + 1))
|
| 311 |
+
+ task_case
|
| 312 |
+
+ task_instruction.replace("OPTION", str(len(candidate_memos) + 1))
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
output = llm_call(client, prompt, max_tokens=2048, temperature=0.2)
|
| 316 |
+
|
| 317 |
+
# Parse selected indices
|
| 318 |
+
selected_memos = []
|
| 319 |
+
for part in output.split("#"):
|
| 320 |
+
part = part.strip()
|
| 321 |
+
try:
|
| 322 |
+
idx = int(part) - 1
|
| 323 |
+
if 0 <= idx < len(candidate_memos):
|
| 324 |
+
selected_memos.append(candidate_memos[idx])
|
| 325 |
+
except ValueError:
|
| 326 |
+
continue
|
| 327 |
+
|
| 328 |
+
return selected_memos
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# ---------------------------------------------------------------------------
|
| 332 |
+
# Stage 4: Answer generation with MemoChat chatting prompt
|
| 333 |
+
# ---------------------------------------------------------------------------
|
| 334 |
+
|
| 335 |
+
def generate_answer(
|
| 336 |
+
question: str,
|
| 337 |
+
question_date: str,
|
| 338 |
+
retrieved_memos: List[Dict],
|
| 339 |
+
user_profile: Optional[str],
|
| 340 |
+
client,
|
| 341 |
+
prompts: Dict,
|
| 342 |
+
) -> str:
|
| 343 |
+
"""Generate an answer using MemoChat's chatting prompt format."""
|
| 344 |
+
system_instruction = prompts["chatting"]["system"]
|
| 345 |
+
|
| 346 |
+
# Build "Related Evidences" section from retrieved memos
|
| 347 |
+
evidence_lines = []
|
| 348 |
+
for i, m in enumerate(retrieved_memos):
|
| 349 |
+
evidence_lines.append(
|
| 350 |
+
f"({i + 1}) {{'Related Topics': '{m['topic']}', "
|
| 351 |
+
f"'Related Summaries': '{m['summary']}'}}"
|
| 352 |
+
)
|
| 353 |
+
evidences_str = "\n".join(evidence_lines) if evidence_lines else "(No related evidences found)"
|
| 354 |
+
|
| 355 |
+
# Build the chatting prompt
|
| 356 |
+
profile_section = ""
|
| 357 |
+
if user_profile:
|
| 358 |
+
profile_section = f"\nUser Profile:\n{user_profile}\n"
|
| 359 |
+
|
| 360 |
+
task_case = (
|
| 361 |
+
f"```\nRelated Evidences:\n{evidences_str}"
|
| 362 |
+
f"\n\nRecent Dialogs:\n(no recent dialogs)"
|
| 363 |
+
f"\n```"
|
| 364 |
+
f"{profile_section}"
|
| 365 |
+
f"\n\nCurrent Date: {question_date}"
|
| 366 |
+
f"\n\nUser Input:\nuser: {question} ### bot: "
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
prompt = system_instruction + task_case
|
| 370 |
+
|
| 371 |
+
output = llm_call(client, prompt, max_tokens=8192, temperature=0.2)
|
| 372 |
+
return normalize_chatting_outputs(output)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# ---------------------------------------------------------------------------
|
| 376 |
+
# Main
|
| 377 |
+
# ---------------------------------------------------------------------------
|
| 378 |
+
|
| 379 |
+
# ---------------------------------------------------------------------------
|
| 380 |
+
# Retrieval metrics
|
| 381 |
+
# ---------------------------------------------------------------------------
|
| 382 |
+
|
| 383 |
+
def evaluate_retrieval(recalled_docs, correct_docs):
|
| 384 |
+
recall_any = float(any(doc in recalled_docs for doc in correct_docs))
|
| 385 |
+
recall_all = float(all(doc in recalled_docs for doc in correct_docs))
|
| 386 |
+
return recall_any, recall_all
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def print_average_metrics(retrieval_metric_list):
|
| 390 |
+
metric_sums = defaultdict(float)
|
| 391 |
+
metric_counts = defaultdict(int)
|
| 392 |
+
for metric in retrieval_metric_list:
|
| 393 |
+
for k, v in metric.items():
|
| 394 |
+
metric_sums[k] += v
|
| 395 |
+
metric_counts[k] += 1
|
| 396 |
+
print(" Average retrieval metrics:")
|
| 397 |
+
for k in sorted(metric_sums):
|
| 398 |
+
avg = metric_sums[k] / metric_counts[k]
|
| 399 |
+
print(f" {k}: {avg:.4f}")
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# ---------------------------------------------------------------------------
|
| 403 |
+
# Main
|
| 404 |
+
# ---------------------------------------------------------------------------
|
| 405 |
+
|
| 406 |
+
def main():
|
| 407 |
+
parser = argparse.ArgumentParser(description="MemoChat baseline for EvolV-Mem")
|
| 408 |
+
parser.add_argument("--in_file", type=str, required=True,
|
| 409 |
+
help="Path to evolv_mem_v4.json")
|
| 410 |
+
parser.add_argument("--out_file", type=str, required=True,
|
| 411 |
+
help="Output JSONL file")
|
| 412 |
+
parser.add_argument("--sessions_file", type=str, default=None,
|
| 413 |
+
help="Path to all_sessions.json (only needed with --use_llm_memos)")
|
| 414 |
+
parser.add_argument("--summary_file", type=str, default=None,
|
| 415 |
+
help="Path to all_session_summary.json (used by default for memo bank)")
|
| 416 |
+
parser.add_argument("--profile_file", type=str, default=None,
|
| 417 |
+
help="Path to generated_user_profile.json")
|
| 418 |
+
parser.add_argument("--memo_cache_dir", type=str,
|
| 419 |
+
default="baselines/MemoChat/memo_cache",
|
| 420 |
+
help="Directory to cache per-session memos")
|
| 421 |
+
parser.add_argument("--prompt_file", type=str, default=None,
|
| 422 |
+
help="Path to prompts.json (default: baselines/MemoChat/data/prompts.json)")
|
| 423 |
+
# Retrieval params
|
| 424 |
+
parser.add_argument("--embed_top_k", type=int, default=50,
|
| 425 |
+
help="Number of memos to keep after embedding pre-filter (default 50)")
|
| 426 |
+
parser.add_argument("--embedding_model", type=str,
|
| 427 |
+
default="sentence-transformers/multi-qa-mpnet-base-cos-v1",
|
| 428 |
+
help="SentenceTransformer model for embedding pre-filter")
|
| 429 |
+
# Limit (for debugging)
|
| 430 |
+
parser.add_argument("--limit", type=int, default=None,
|
| 431 |
+
help="Process only the first N questions")
|
| 432 |
+
parser.add_argument("--use_llm_memos", action="store_true", default=False,
|
| 433 |
+
help="Use LLM-based memo writing instead of cached session summaries")
|
| 434 |
+
args = parser.parse_args()
|
| 435 |
+
|
| 436 |
+
# -----------------------------------------------------------------------
|
| 437 |
+
# Load data
|
| 438 |
+
# -----------------------------------------------------------------------
|
| 439 |
+
print(f"Loading benchmark from {args.in_file} ...")
|
| 440 |
+
with open(args.in_file) as f:
|
| 441 |
+
benchmark = json.load(f)
|
| 442 |
+
if args.limit:
|
| 443 |
+
benchmark = benchmark[:args.limit]
|
| 444 |
+
print(f" {len(benchmark)} questions loaded.")
|
| 445 |
+
|
| 446 |
+
all_sessions = {}
|
| 447 |
+
if args.sessions_file and os.path.exists(args.sessions_file):
|
| 448 |
+
print(f"Loading sessions from {args.sessions_file} ...")
|
| 449 |
+
with open(args.sessions_file) as f:
|
| 450 |
+
all_sessions = json.load(f)
|
| 451 |
+
print(f" {len(all_sessions)} sessions loaded.")
|
| 452 |
+
|
| 453 |
+
summaries = {}
|
| 454 |
+
if args.summary_file and os.path.exists(args.summary_file):
|
| 455 |
+
print(f"Loading session summaries from {args.summary_file} ...")
|
| 456 |
+
with open(args.summary_file) as f:
|
| 457 |
+
summaries = json.load(f)
|
| 458 |
+
print(f" {len(summaries)} session summaries loaded.")
|
| 459 |
+
|
| 460 |
+
if not args.use_llm_memos and not summaries:
|
| 461 |
+
print("ERROR: --summary_file is required unless --use_llm_memos is set.")
|
| 462 |
+
sys.exit(1)
|
| 463 |
+
if args.use_llm_memos and not all_sessions:
|
| 464 |
+
print("ERROR: --sessions_file is required when --use_llm_memos is set.")
|
| 465 |
+
sys.exit(1)
|
| 466 |
+
|
| 467 |
+
profiles = {}
|
| 468 |
+
if args.profile_file and os.path.exists(args.profile_file):
|
| 469 |
+
print(f"Loading user profiles from {args.profile_file} ...")
|
| 470 |
+
with open(args.profile_file) as f:
|
| 471 |
+
profiles = json.load(f)
|
| 472 |
+
print(f" {len(profiles)} profiles loaded.")
|
| 473 |
+
|
| 474 |
+
prompt_file = args.prompt_file or PROMPTS_PATH
|
| 475 |
+
print(f"Loading prompts from {prompt_file} ...")
|
| 476 |
+
with open(prompt_file) as f:
|
| 477 |
+
prompts = json.load(f)
|
| 478 |
+
|
| 479 |
+
# -----------------------------------------------------------------------
|
| 480 |
+
# Resume support
|
| 481 |
+
# -----------------------------------------------------------------------
|
| 482 |
+
existing_qids = set()
|
| 483 |
+
if os.path.exists(args.out_file):
|
| 484 |
+
with open(args.out_file) as f:
|
| 485 |
+
for line in f:
|
| 486 |
+
line = line.strip()
|
| 487 |
+
if line:
|
| 488 |
+
obj = json.loads(line)
|
| 489 |
+
existing_qids.add(obj["question_id"])
|
| 490 |
+
print(f" Resuming: {len(existing_qids)} questions already processed.")
|
| 491 |
+
|
| 492 |
+
# -----------------------------------------------------------------------
|
| 493 |
+
# Initialize models
|
| 494 |
+
# -----------------------------------------------------------------------
|
| 495 |
+
print("Initializing embedding model ...")
|
| 496 |
+
from sentence_transformers import SentenceTransformer
|
| 497 |
+
embedding_model = SentenceTransformer(args.embedding_model)
|
| 498 |
+
|
| 499 |
+
print("Initializing vLLM client ...")
|
| 500 |
+
client = get_llm_client()
|
| 501 |
+
|
| 502 |
+
os.makedirs(args.memo_cache_dir, exist_ok=True)
|
| 503 |
+
|
| 504 |
+
# -----------------------------------------------------------------------
|
| 505 |
+
# Process questions
|
| 506 |
+
# -----------------------------------------------------------------------
|
| 507 |
+
retrieval_metric_list = []
|
| 508 |
+
out_f = open(args.out_file, "a")
|
| 509 |
+
|
| 510 |
+
for di, entry in enumerate(tqdm(benchmark, desc="MemoChat baseline")):
|
| 511 |
+
qid = entry["question_id"]
|
| 512 |
+
question = entry["question"]
|
| 513 |
+
question_date = entry["question_date"]
|
| 514 |
+
|
| 515 |
+
if qid in existing_qids:
|
| 516 |
+
continue
|
| 517 |
+
|
| 518 |
+
try:
|
| 519 |
+
haystack_session_ids = entry["haystack_session_ids"]
|
| 520 |
+
|
| 521 |
+
# ------ Stage 1: Build memo bank ------
|
| 522 |
+
if args.use_llm_memos:
|
| 523 |
+
# Slow path: LLM-based memo writing (cached per session)
|
| 524 |
+
all_memos = []
|
| 525 |
+
n_cached = 0
|
| 526 |
+
n_written = 0
|
| 527 |
+
date_lookup = dict(zip(
|
| 528 |
+
entry["haystack_session_ids"], entry["haystack_dates"]
|
| 529 |
+
))
|
| 530 |
+
for sid in haystack_session_ids:
|
| 531 |
+
session_turns = all_sessions.get(sid, [])
|
| 532 |
+
cache_exists = os.path.exists(
|
| 533 |
+
os.path.join(args.memo_cache_dir, f"{sid}.json")
|
| 534 |
+
)
|
| 535 |
+
memos = write_memos_for_session(
|
| 536 |
+
sid, session_turns, client, prompts, args.memo_cache_dir
|
| 537 |
+
)
|
| 538 |
+
for m in memos:
|
| 539 |
+
m["session_date"] = date_lookup.get(sid, "")
|
| 540 |
+
all_memos.extend(memos)
|
| 541 |
+
if cache_exists:
|
| 542 |
+
n_cached += 1
|
| 543 |
+
else:
|
| 544 |
+
n_written += 1
|
| 545 |
+
print(f" [{di}] qid={qid}: {len(all_memos)} memos "
|
| 546 |
+
f"({n_cached} cached, {n_written} new)")
|
| 547 |
+
else:
|
| 548 |
+
# Fast path: use pre-computed session summaries as memos
|
| 549 |
+
all_memos = build_memos_from_summaries(
|
| 550 |
+
haystack_session_ids, entry["haystack_dates"], summaries
|
| 551 |
+
)
|
| 552 |
+
print(f" [{di}] qid={qid}: {len(all_memos)} memos from summaries")
|
| 553 |
+
|
| 554 |
+
if not all_memos:
|
| 555 |
+
result = {
|
| 556 |
+
"q_idx": di,
|
| 557 |
+
"question_id": qid,
|
| 558 |
+
"hypothesis": "Insufficient information to answer.",
|
| 559 |
+
"n_memos": 0,
|
| 560 |
+
}
|
| 561 |
+
print(json.dumps(result), file=out_f, flush=True)
|
| 562 |
+
continue
|
| 563 |
+
|
| 564 |
+
# ------ Stage 2: Embedding pre-filter ------
|
| 565 |
+
filtered_memos = embed_and_filter(
|
| 566 |
+
question, all_memos, embedding_model, top_k=args.embed_top_k
|
| 567 |
+
)
|
| 568 |
+
print(f" [{di}] Embedding filter: {len(all_memos)} -> {len(filtered_memos)} memos")
|
| 569 |
+
|
| 570 |
+
# ------ Stage 3: MemoChat LLM retrieval ------
|
| 571 |
+
retrieved_memos = memochat_retrieve(
|
| 572 |
+
question, filtered_memos, client, prompts
|
| 573 |
+
)
|
| 574 |
+
print(f" [{di}] MemoChat retrieval: {len(filtered_memos)} -> {len(retrieved_memos)} memos")
|
| 575 |
+
|
| 576 |
+
# Fallback: if retrieval selected nothing, use top-5 from embedding filter
|
| 577 |
+
if not retrieved_memos:
|
| 578 |
+
retrieved_memos = filtered_memos[:5]
|
| 579 |
+
print(f" [{di}] Fallback: using top-5 from embedding filter")
|
| 580 |
+
|
| 581 |
+
# ------ Stage 4: Answer generation ------
|
| 582 |
+
user_id = qid.split("_q_")[0] if "_q_" in qid else qid
|
| 583 |
+
user_profile = profiles.get(user_id, None)
|
| 584 |
+
|
| 585 |
+
answer = generate_answer(
|
| 586 |
+
question, question_date, retrieved_memos,
|
| 587 |
+
user_profile, client, prompts
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
# ------ Output ------
|
| 591 |
+
retrieved_session_ids = list(dict.fromkeys(
|
| 592 |
+
m["session_id"] for m in retrieved_memos if "session_id" in m
|
| 593 |
+
))
|
| 594 |
+
|
| 595 |
+
# Compute retrieval metrics
|
| 596 |
+
answer_session_ids = entry.get("answer_session_ids", [])
|
| 597 |
+
retrieval_metric = {}
|
| 598 |
+
if answer_session_ids and retrieved_session_ids:
|
| 599 |
+
for topk in [5, 10, 20, 30]:
|
| 600 |
+
r_any, r_all = evaluate_retrieval(
|
| 601 |
+
retrieved_session_ids[:topk], answer_session_ids
|
| 602 |
+
)
|
| 603 |
+
retrieval_metric[f"recall_any@{topk}"] = r_any
|
| 604 |
+
retrieval_metric[f"recall_all@{topk}"] = r_all
|
| 605 |
+
retrieval_metric_list.append(retrieval_metric)
|
| 606 |
+
print_average_metrics(retrieval_metric_list)
|
| 607 |
+
|
| 608 |
+
result = {
|
| 609 |
+
"q_idx": di,
|
| 610 |
+
"question_id": qid,
|
| 611 |
+
"hypothesis": answer,
|
| 612 |
+
"n_memos_total": len(all_memos),
|
| 613 |
+
"n_memos_filtered": len(filtered_memos),
|
| 614 |
+
"n_memos_retrieved": len(retrieved_memos),
|
| 615 |
+
"retrieved_session_ids": retrieved_session_ids,
|
| 616 |
+
"retrieval_metric": retrieval_metric,
|
| 617 |
+
}
|
| 618 |
+
print(json.dumps(result), file=out_f, flush=True)
|
| 619 |
+
|
| 620 |
+
print(f" [{di}] Q: {question[:100]}...")
|
| 621 |
+
print(f" [{di}] A: {answer[:200]}...")
|
| 622 |
+
|
| 623 |
+
except Exception as e:
|
| 624 |
+
print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
|
| 625 |
+
import traceback
|
| 626 |
+
traceback.print_exc()
|
| 627 |
+
continue
|
| 628 |
+
|
| 629 |
+
out_f.close()
|
| 630 |
+
print(f"\nDone. Results saved to {args.out_file}")
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
if __name__ == "__main__":
|
| 634 |
+
main()
|
baselines/raptor/LICENSE.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Parth Sarthi
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in
|
| 13 |
+
all copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
| 21 |
+
THE SOFTWARE.
|
baselines/raptor/README.md
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- <p align="center">
|
| 2 |
+
<img align="center" src="raptor.jpg" width="1000px" />
|
| 3 |
+
</p>
|
| 4 |
+
<p align="left"> -->
|
| 5 |
+
|
| 6 |
+
<!-- <picture>
|
| 7 |
+
<source media="(prefers-color-scheme: dark)" srcset="raptor.jpg" width="1000px">
|
| 8 |
+
<source media="(prefers-color-scheme: light)" srcset="raptor_dark.png" width="1000px">
|
| 9 |
+
|
| 10 |
+
</picture> -->
|
| 11 |
+
|
| 12 |
+
<picture>
|
| 13 |
+
<source media="(prefers-color-scheme: dark)" srcset="raptor_dark.png">
|
| 14 |
+
<img alt="Shows an illustrated sun in light color mode and a moon with stars in dark color mode." src="raptor.jpg">
|
| 15 |
+
</picture>
|
| 16 |
+
|
| 17 |
+
## RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval
|
| 18 |
+
|
| 19 |
+
**RAPTOR** introduces a novel approach to retrieval-augmented language models by constructing a recursive tree structure from documents. This allows for more efficient and context-aware information retrieval across large texts, addressing common limitations in traditional language models.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
For detailed methodologies and implementations, refer to the original paper:
|
| 24 |
+
|
| 25 |
+
- [RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval](https://arxiv.org/abs/2401.18059)
|
| 26 |
+
|
| 27 |
+
[](https://huggingface.co/papers/2401.18059)
|
| 28 |
+
[](https://paperswithcode.com/sota/question-answering-on-quality?p=raptor-recursive-abstractive-processing-for)
|
| 29 |
+
|
| 30 |
+
## Installation
|
| 31 |
+
|
| 32 |
+
Before using RAPTOR, ensure Python 3.8+ is installed. Clone the RAPTOR repository and install necessary dependencies:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
git clone https://github.com/parthsarthi03/raptor.git
|
| 36 |
+
cd raptor
|
| 37 |
+
pip install -r requirements.txt
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Basic Usage
|
| 41 |
+
|
| 42 |
+
To get started with RAPTOR, follow these steps:
|
| 43 |
+
|
| 44 |
+
### Setting Up RAPTOR
|
| 45 |
+
|
| 46 |
+
First, set your OpenAI API key and initialize the RAPTOR configuration:
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
import os
|
| 50 |
+
os.environ["OPENAI_API_KEY"] = "your-openai-api-key"
|
| 51 |
+
|
| 52 |
+
from raptor import RetrievalAugmentation
|
| 53 |
+
|
| 54 |
+
# Initialize with default configuration. For advanced configurations, check the documentation. [WIP]
|
| 55 |
+
RA = RetrievalAugmentation()
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Adding Documents to the Tree
|
| 59 |
+
|
| 60 |
+
Add your text documents to RAPTOR for indexing:
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
with open('sample.txt', 'r') as file:
|
| 64 |
+
text = file.read()
|
| 65 |
+
RA.add_documents(text)
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
### Answering Questions
|
| 69 |
+
|
| 70 |
+
You can now use RAPTOR to answer questions based on the indexed documents:
|
| 71 |
+
|
| 72 |
+
```python
|
| 73 |
+
question = "How did Cinderella reach her happy ending?"
|
| 74 |
+
answer = RA.answer_question(question=question)
|
| 75 |
+
print("Answer: ", answer)
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### Saving and Loading the Tree
|
| 79 |
+
|
| 80 |
+
Save the constructed tree to a specified path:
|
| 81 |
+
|
| 82 |
+
```python
|
| 83 |
+
SAVE_PATH = "demo/cinderella"
|
| 84 |
+
RA.save(SAVE_PATH)
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
Load the saved tree back into RAPTOR:
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
RA = RetrievalAugmentation(tree=SAVE_PATH)
|
| 91 |
+
answer = RA.answer_question(question=question)
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
### Extending RAPTOR with other Models
|
| 96 |
+
|
| 97 |
+
RAPTOR is designed to be flexible and allows you to integrate any models for summarization, question-answering (QA), and embedding generation. Here is how to extend RAPTOR with your own models:
|
| 98 |
+
|
| 99 |
+
#### Custom Summarization Model
|
| 100 |
+
|
| 101 |
+
If you wish to use a different language model for summarization, you can do so by extending the `BaseSummarizationModel` class. Implement the `summarize` method to integrate your custom summarization logic:
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
from raptor import BaseSummarizationModel
|
| 105 |
+
|
| 106 |
+
class CustomSummarizationModel(BaseSummarizationModel):
|
| 107 |
+
def __init__(self):
|
| 108 |
+
# Initialize your model here
|
| 109 |
+
pass
|
| 110 |
+
|
| 111 |
+
def summarize(self, context, max_tokens=150):
|
| 112 |
+
# Implement your summarization logic here
|
| 113 |
+
# Return the summary as a string
|
| 114 |
+
summary = "Your summary here"
|
| 115 |
+
return summary
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
#### Custom QA Model
|
| 119 |
+
|
| 120 |
+
For custom QA models, extend the `BaseQAModel` class and implement the `answer_question` method. This method should return the best answer found by your model given a context and a question:
|
| 121 |
+
|
| 122 |
+
```python
|
| 123 |
+
from raptor import BaseQAModel
|
| 124 |
+
|
| 125 |
+
class CustomQAModel(BaseQAModel):
|
| 126 |
+
def __init__(self):
|
| 127 |
+
# Initialize your model here
|
| 128 |
+
pass
|
| 129 |
+
|
| 130 |
+
def answer_question(self, context, question):
|
| 131 |
+
# Implement your QA logic here
|
| 132 |
+
# Return the answer as a string
|
| 133 |
+
answer = "Your answer here"
|
| 134 |
+
return answer
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
#### Custom Embedding Model
|
| 138 |
+
|
| 139 |
+
To use a different embedding model, extend the `BaseEmbeddingModel` class. Implement the `create_embedding` method, which should return a vector representation of the input text:
|
| 140 |
+
|
| 141 |
+
```python
|
| 142 |
+
from raptor import BaseEmbeddingModel
|
| 143 |
+
|
| 144 |
+
class CustomEmbeddingModel(BaseEmbeddingModel):
|
| 145 |
+
def __init__(self):
|
| 146 |
+
# Initialize your model here
|
| 147 |
+
pass
|
| 148 |
+
|
| 149 |
+
def create_embedding(self, text):
|
| 150 |
+
# Implement your embedding logic here
|
| 151 |
+
# Return the embedding as a numpy array or a list of floats
|
| 152 |
+
embedding = [0.0] * embedding_dim # Replace with actual embedding logic
|
| 153 |
+
return embedding
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
#### Integrating Custom Models with RAPTOR
|
| 157 |
+
|
| 158 |
+
After implementing your custom models, integrate them with RAPTOR as follows:
|
| 159 |
+
|
| 160 |
+
```python
|
| 161 |
+
from raptor import RetrievalAugmentation, RetrievalAugmentationConfig
|
| 162 |
+
|
| 163 |
+
# Initialize your custom models
|
| 164 |
+
custom_summarizer = CustomSummarizationModel()
|
| 165 |
+
custom_qa = CustomQAModel()
|
| 166 |
+
custom_embedding = CustomEmbeddingModel()
|
| 167 |
+
|
| 168 |
+
# Create a config with your custom models
|
| 169 |
+
custom_config = RetrievalAugmentationConfig(
|
| 170 |
+
summarization_model=custom_summarizer,
|
| 171 |
+
qa_model=custom_qa,
|
| 172 |
+
embedding_model=custom_embedding
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# Initialize RAPTOR with your custom config
|
| 176 |
+
RA = RetrievalAugmentation(config=custom_config)
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
Check out `demo.ipynb` for examples on how to specify your own summarization/QA models, such as Llama/Mistral/Gemma, and Embedding Models such as SBERT, for use with RAPTOR.
|
| 180 |
+
|
| 181 |
+
Note: More examples and ways to configure RAPTOR are forthcoming. Advanced usage and additional features will be provided in the documentation and repository updates.
|
| 182 |
+
|
| 183 |
+
## Contributing
|
| 184 |
+
|
| 185 |
+
RAPTOR is an open-source project, and contributions are welcome. Whether you're fixing bugs, adding new features, or improving documentation, your help is appreciated.
|
| 186 |
+
|
| 187 |
+
## License
|
| 188 |
+
|
| 189 |
+
RAPTOR is released under the MIT License. See the LICENSE file in the repository for full details.
|
| 190 |
+
|
| 191 |
+
## Citation
|
| 192 |
+
|
| 193 |
+
If RAPTOR assists in your research, please cite it as follows:
|
| 194 |
+
|
| 195 |
+
```bibtex
|
| 196 |
+
@inproceedings{sarthi2024raptor,
|
| 197 |
+
title={RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval},
|
| 198 |
+
author={Sarthi, Parth and Abdullah, Salman and Tuli, Aditi and Khanna, Shubh and Goldie, Anna and Manning, Christopher D.},
|
| 199 |
+
booktitle={International Conference on Learning Representations (ICLR)},
|
| 200 |
+
year={2024}
|
| 201 |
+
}
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
Stay tuned for more examples, configuration guides, and updates.
|
baselines/raptor/raptor/EmbeddingModels.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseEmbeddingModel(ABC):
|
| 12 |
+
@abstractmethod
|
| 13 |
+
def create_embedding(self, text):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class OpenAIEmbeddingModel(BaseEmbeddingModel):
|
| 18 |
+
def __init__(self, model="text-embedding-ada-002"):
|
| 19 |
+
self.client = OpenAI()
|
| 20 |
+
self.model = model
|
| 21 |
+
|
| 22 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 23 |
+
def create_embedding(self, text):
|
| 24 |
+
text = text.replace("\n", " ")
|
| 25 |
+
return (
|
| 26 |
+
self.client.embeddings.create(input=[text], model=self.model)
|
| 27 |
+
.data[0]
|
| 28 |
+
.embedding
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SBertEmbeddingModel(BaseEmbeddingModel):
|
| 33 |
+
def __init__(self, model_name="sentence-transformers/multi-qa-mpnet-base-cos-v1"):
|
| 34 |
+
self.model = SentenceTransformer(model_name)
|
| 35 |
+
|
| 36 |
+
def create_embedding(self, text):
|
| 37 |
+
return self.model.encode(text)
|
baselines/raptor/raptor/FaissRetriever.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 3 |
+
|
| 4 |
+
import faiss
|
| 5 |
+
import numpy as np
|
| 6 |
+
import tiktoken
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel
|
| 10 |
+
from .Retrievers import BaseRetriever
|
| 11 |
+
from .utils import split_text
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FaissRetrieverConfig:
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
max_tokens=100,
|
| 18 |
+
max_context_tokens=3500,
|
| 19 |
+
use_top_k=False,
|
| 20 |
+
embedding_model=None,
|
| 21 |
+
question_embedding_model=None,
|
| 22 |
+
top_k=5,
|
| 23 |
+
tokenizer=tiktoken.get_encoding("cl100k_base"),
|
| 24 |
+
embedding_model_string=None,
|
| 25 |
+
):
|
| 26 |
+
if max_tokens < 1:
|
| 27 |
+
raise ValueError("max_tokens must be at least 1")
|
| 28 |
+
|
| 29 |
+
if top_k < 1:
|
| 30 |
+
raise ValueError("top_k must be at least 1")
|
| 31 |
+
|
| 32 |
+
if max_context_tokens is not None and max_context_tokens < 1:
|
| 33 |
+
raise ValueError("max_context_tokens must be at least 1 or None")
|
| 34 |
+
|
| 35 |
+
if embedding_model is not None and not isinstance(
|
| 36 |
+
embedding_model, BaseEmbeddingModel
|
| 37 |
+
):
|
| 38 |
+
raise ValueError(
|
| 39 |
+
"embedding_model must be an instance of BaseEmbeddingModel or None"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if question_embedding_model is not None and not isinstance(
|
| 43 |
+
question_embedding_model, BaseEmbeddingModel
|
| 44 |
+
):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
"question_embedding_model must be an instance of BaseEmbeddingModel or None"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
self.top_k = top_k
|
| 50 |
+
self.max_tokens = max_tokens
|
| 51 |
+
self.max_context_tokens = max_context_tokens
|
| 52 |
+
self.use_top_k = use_top_k
|
| 53 |
+
self.embedding_model = embedding_model or OpenAIEmbeddingModel()
|
| 54 |
+
self.question_embedding_model = question_embedding_model or self.embedding_model
|
| 55 |
+
self.tokenizer = tokenizer
|
| 56 |
+
self.embedding_model_string = embedding_model_string or "OpenAI"
|
| 57 |
+
|
| 58 |
+
def log_config(self):
|
| 59 |
+
config_summary = """
|
| 60 |
+
FaissRetrieverConfig:
|
| 61 |
+
Max Tokens: {max_tokens}
|
| 62 |
+
Max Context Tokens: {max_context_tokens}
|
| 63 |
+
Use Top K: {use_top_k}
|
| 64 |
+
Embedding Model: {embedding_model}
|
| 65 |
+
Question Embedding Model: {question_embedding_model}
|
| 66 |
+
Top K: {top_k}
|
| 67 |
+
Tokenizer: {tokenizer}
|
| 68 |
+
Embedding Model String: {embedding_model_string}
|
| 69 |
+
""".format(
|
| 70 |
+
max_tokens=self.max_tokens,
|
| 71 |
+
max_context_tokens=self.max_context_tokens,
|
| 72 |
+
use_top_k=self.use_top_k,
|
| 73 |
+
embedding_model=self.embedding_model,
|
| 74 |
+
question_embedding_model=self.question_embedding_model,
|
| 75 |
+
top_k=self.top_k,
|
| 76 |
+
tokenizer=self.tokenizer,
|
| 77 |
+
embedding_model_string=self.embedding_model_string,
|
| 78 |
+
)
|
| 79 |
+
return config_summary
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class FaissRetriever(BaseRetriever):
|
| 83 |
+
"""
|
| 84 |
+
FaissRetriever is a class that retrieves similar context chunks for a given query using Faiss.
|
| 85 |
+
encoders_type is 'same' if the question and context encoder is the same,
|
| 86 |
+
otherwise, encoders_type is 'different'.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
def __init__(self, config):
|
| 90 |
+
self.embedding_model = config.embedding_model
|
| 91 |
+
self.question_embedding_model = config.question_embedding_model
|
| 92 |
+
self.index = None
|
| 93 |
+
self.context_chunks = None
|
| 94 |
+
self.max_tokens = config.max_tokens
|
| 95 |
+
self.max_context_tokens = config.max_context_tokens
|
| 96 |
+
self.use_top_k = config.use_top_k
|
| 97 |
+
self.tokenizer = config.tokenizer
|
| 98 |
+
self.top_k = config.top_k
|
| 99 |
+
self.embedding_model_string = config.embedding_model_string
|
| 100 |
+
|
| 101 |
+
def build_from_text(self, doc_text):
|
| 102 |
+
"""
|
| 103 |
+
Builds the index from a given text.
|
| 104 |
+
|
| 105 |
+
:param doc_text: A string containing the document text.
|
| 106 |
+
:param tokenizer: A tokenizer used to split the text into chunks.
|
| 107 |
+
:param max_tokens: An integer representing the maximum number of tokens per chunk.
|
| 108 |
+
"""
|
| 109 |
+
self.context_chunks = np.array(
|
| 110 |
+
split_text(doc_text, self.tokenizer, self.max_tokens)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
with ProcessPoolExecutor() as executor:
|
| 114 |
+
futures = [
|
| 115 |
+
executor.submit(self.embedding_model.create_embedding, context_chunk)
|
| 116 |
+
for context_chunk in self.context_chunks
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
self.embeddings = []
|
| 120 |
+
for future in tqdm(futures, total=len(futures), desc="Building embeddings"):
|
| 121 |
+
self.embeddings.append(future.result())
|
| 122 |
+
|
| 123 |
+
self.embeddings = np.array(self.embeddings, dtype=np.float32)
|
| 124 |
+
|
| 125 |
+
self.index = faiss.IndexFlatIP(self.embeddings.shape[1])
|
| 126 |
+
self.index.add(self.embeddings)
|
| 127 |
+
|
| 128 |
+
def build_from_leaf_nodes(self, leaf_nodes):
|
| 129 |
+
"""
|
| 130 |
+
Builds the index from a given text.
|
| 131 |
+
|
| 132 |
+
:param doc_text: A string containing the document text.
|
| 133 |
+
:param tokenizer: A tokenizer used to split the text into chunks.
|
| 134 |
+
:param max_tokens: An integer representing the maximum number of tokens per chunk.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
self.context_chunks = [node.text for node in leaf_nodes]
|
| 138 |
+
|
| 139 |
+
self.embeddings = np.array(
|
| 140 |
+
[node.embeddings[self.embedding_model_string] for node in leaf_nodes],
|
| 141 |
+
dtype=np.float32,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
self.index = faiss.IndexFlatIP(self.embeddings.shape[1])
|
| 145 |
+
self.index.add(self.embeddings)
|
| 146 |
+
|
| 147 |
+
def sanity_check(self, num_samples=4):
|
| 148 |
+
"""
|
| 149 |
+
Perform a sanity check by recomputing embeddings of a few randomly-selected chunks.
|
| 150 |
+
|
| 151 |
+
:param num_samples: The number of samples to test.
|
| 152 |
+
"""
|
| 153 |
+
indices = random.sample(range(len(self.context_chunks)), num_samples)
|
| 154 |
+
|
| 155 |
+
for i in indices:
|
| 156 |
+
original_embedding = self.embeddings[i]
|
| 157 |
+
recomputed_embedding = self.embedding_model.create_embedding(
|
| 158 |
+
self.context_chunks[i]
|
| 159 |
+
)
|
| 160 |
+
assert np.allclose(
|
| 161 |
+
original_embedding, recomputed_embedding
|
| 162 |
+
), f"Embeddings do not match for index {i}!"
|
| 163 |
+
|
| 164 |
+
print(f"Sanity check passed for {num_samples} random samples.")
|
| 165 |
+
|
| 166 |
+
def retrieve(self, query: str) -> str:
|
| 167 |
+
"""
|
| 168 |
+
Retrieves the k most similar context chunks for a given query.
|
| 169 |
+
|
| 170 |
+
:param query: A string containing the query.
|
| 171 |
+
:param k: An integer representing the number of similar context chunks to retrieve.
|
| 172 |
+
:return: A string containing the retrieved context chunks.
|
| 173 |
+
"""
|
| 174 |
+
query_embedding = np.array(
|
| 175 |
+
[
|
| 176 |
+
np.array(
|
| 177 |
+
self.question_embedding_model.create_embedding(query),
|
| 178 |
+
dtype=np.float32,
|
| 179 |
+
).squeeze()
|
| 180 |
+
]
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
context = ""
|
| 184 |
+
|
| 185 |
+
if self.use_top_k:
|
| 186 |
+
_, indices = self.index.search(query_embedding, self.top_k)
|
| 187 |
+
for i in range(self.top_k):
|
| 188 |
+
context += self.context_chunks[indices[0][i]]
|
| 189 |
+
|
| 190 |
+
else:
|
| 191 |
+
range_ = int(self.max_context_tokens / self.max_tokens)
|
| 192 |
+
_, indices = self.index.search(query_embedding, range_)
|
| 193 |
+
total_tokens = 0
|
| 194 |
+
for i in range(range_):
|
| 195 |
+
tokens = len(self.tokenizer.encode(self.context_chunks[indices[0][i]]))
|
| 196 |
+
context += self.context_chunks[indices[0][i]]
|
| 197 |
+
if total_tokens + tokens > self.max_context_tokens:
|
| 198 |
+
break
|
| 199 |
+
total_tokens += tokens
|
| 200 |
+
|
| 201 |
+
return context
|
baselines/raptor/raptor/QAModels.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import getpass
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
| 12 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BaseQAModel(ABC):
|
| 16 |
+
@abstractmethod
|
| 17 |
+
def answer_question(self, context, question):
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class GPT3QAModel(BaseQAModel):
|
| 22 |
+
def __init__(self, model="text-davinci-003"):
|
| 23 |
+
"""
|
| 24 |
+
Initializes the GPT-3 model with the specified model version.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
model (str, optional): The GPT-3 model version to use for generating summaries. Defaults to "text-davinci-003".
|
| 28 |
+
"""
|
| 29 |
+
self.model = model
|
| 30 |
+
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
| 31 |
+
|
| 32 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 33 |
+
def answer_question(self, context, question, max_tokens=150, stop_sequence=None):
|
| 34 |
+
"""
|
| 35 |
+
Generates a summary of the given context using the GPT-3 model.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
context (str): The text to summarize.
|
| 39 |
+
max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.
|
| 40 |
+
stop_sequence (str, optional): The sequence at which to stop summarization. Defaults to None.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
str: The generated summary.
|
| 44 |
+
"""
|
| 45 |
+
try:
|
| 46 |
+
response = self.client.completions.create(
|
| 47 |
+
prompt=f"using the folloing information {context}. Answer the following question in less than 5-7 words, if possible: {question}",
|
| 48 |
+
temperature=0,
|
| 49 |
+
max_tokens=max_tokens,
|
| 50 |
+
top_p=1,
|
| 51 |
+
frequency_penalty=0,
|
| 52 |
+
presence_penalty=0,
|
| 53 |
+
stop=stop_sequence,
|
| 54 |
+
model=self.model,
|
| 55 |
+
)
|
| 56 |
+
return response.choices[0].text.strip()
|
| 57 |
+
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(e)
|
| 60 |
+
return ""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GPT3TurboQAModel(BaseQAModel):
|
| 64 |
+
def __init__(self, model="gpt-3.5-turbo"):
|
| 65 |
+
"""
|
| 66 |
+
Initializes the GPT-3 model with the specified model version.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
model (str, optional): The GPT-3 model version to use for generating summaries. Defaults to "text-davinci-003".
|
| 70 |
+
"""
|
| 71 |
+
self.model = model
|
| 72 |
+
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
| 73 |
+
|
| 74 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 75 |
+
def _attempt_answer_question(
|
| 76 |
+
self, context, question, max_tokens=150, stop_sequence=None
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
Generates a summary of the given context using the GPT-3 model.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
context (str): The text to summarize.
|
| 83 |
+
max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.
|
| 84 |
+
stop_sequence (str, optional): The sequence at which to stop summarization. Defaults to None.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
str: The generated summary.
|
| 88 |
+
"""
|
| 89 |
+
response = self.client.chat.completions.create(
|
| 90 |
+
model=self.model,
|
| 91 |
+
messages=[
|
| 92 |
+
{"role": "system", "content": "You are Question Answering Portal"},
|
| 93 |
+
{
|
| 94 |
+
"role": "user",
|
| 95 |
+
"content": f"Given Context: {context} Give the best full answer amongst the option to question {question}",
|
| 96 |
+
},
|
| 97 |
+
],
|
| 98 |
+
temperature=0,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
return response.choices[0].message.content.strip()
|
| 102 |
+
|
| 103 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 104 |
+
def answer_question(self, context, question, max_tokens=150, stop_sequence=None):
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
return self._attempt_answer_question(
|
| 108 |
+
context, question, max_tokens=max_tokens, stop_sequence=stop_sequence
|
| 109 |
+
)
|
| 110 |
+
except Exception as e:
|
| 111 |
+
print(e)
|
| 112 |
+
return e
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class GPT4QAModel(BaseQAModel):
|
| 116 |
+
def __init__(self, model="gpt-4"):
|
| 117 |
+
"""
|
| 118 |
+
Initializes the GPT-3 model with the specified model version.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
model (str, optional): The GPT-3 model version to use for generating summaries. Defaults to "text-davinci-003".
|
| 122 |
+
"""
|
| 123 |
+
self.model = model
|
| 124 |
+
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
| 125 |
+
|
| 126 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 127 |
+
def _attempt_answer_question(
|
| 128 |
+
self, context, question, max_tokens=150, stop_sequence=None
|
| 129 |
+
):
|
| 130 |
+
"""
|
| 131 |
+
Generates a summary of the given context using the GPT-3 model.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
context (str): The text to summarize.
|
| 135 |
+
max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.
|
| 136 |
+
stop_sequence (str, optional): The sequence at which to stop summarization. Defaults to None.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
str: The generated summary.
|
| 140 |
+
"""
|
| 141 |
+
response = self.client.chat.completions.create(
|
| 142 |
+
model=self.model,
|
| 143 |
+
messages=[
|
| 144 |
+
{"role": "system", "content": "You are Question Answering Portal"},
|
| 145 |
+
{
|
| 146 |
+
"role": "user",
|
| 147 |
+
"content": f"Given Context: {context} Give the best full answer amongst the option to question {question}",
|
| 148 |
+
},
|
| 149 |
+
],
|
| 150 |
+
temperature=0,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return response.choices[0].message.content.strip()
|
| 154 |
+
|
| 155 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 156 |
+
def answer_question(self, context, question, max_tokens=150, stop_sequence=None):
|
| 157 |
+
|
| 158 |
+
try:
|
| 159 |
+
return self._attempt_answer_question(
|
| 160 |
+
context, question, max_tokens=max_tokens, stop_sequence=stop_sequence
|
| 161 |
+
)
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(e)
|
| 164 |
+
return e
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class UnifiedQAModel(BaseQAModel):
|
| 168 |
+
def __init__(self, model_name="allenai/unifiedqa-v2-t5-3b-1363200"):
|
| 169 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 170 |
+
self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(
|
| 171 |
+
self.device
|
| 172 |
+
)
|
| 173 |
+
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
|
| 174 |
+
|
| 175 |
+
def run_model(self, input_string, **generator_args):
|
| 176 |
+
input_ids = self.tokenizer.encode(input_string, return_tensors="pt").to(
|
| 177 |
+
self.device
|
| 178 |
+
)
|
| 179 |
+
res = self.model.generate(input_ids, **generator_args)
|
| 180 |
+
return self.tokenizer.batch_decode(res, skip_special_tokens=True)
|
| 181 |
+
|
| 182 |
+
def answer_question(self, context, question):
|
| 183 |
+
input_string = question + " \\n " + context
|
| 184 |
+
output = self.run_model(input_string)
|
| 185 |
+
return output[0]
|
baselines/raptor/raptor/RetrievalAugmentation.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
from .cluster_tree_builder import ClusterTreeBuilder, ClusterTreeConfig
|
| 5 |
+
from .EmbeddingModels import BaseEmbeddingModel
|
| 6 |
+
from .QAModels import BaseQAModel, GPT3TurboQAModel
|
| 7 |
+
from .SummarizationModels import BaseSummarizationModel
|
| 8 |
+
from .tree_builder import TreeBuilder, TreeBuilderConfig
|
| 9 |
+
from .tree_retriever import TreeRetriever, TreeRetrieverConfig
|
| 10 |
+
from .tree_structures import Node, Tree
|
| 11 |
+
|
| 12 |
+
# Define a dictionary to map supported tree builders to their respective configs
|
| 13 |
+
supported_tree_builders = {"cluster": (ClusterTreeBuilder, ClusterTreeConfig)}
|
| 14 |
+
|
| 15 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RetrievalAugmentationConfig:
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
tree_builder_config=None,
|
| 22 |
+
tree_retriever_config=None, # Change from default instantiation
|
| 23 |
+
qa_model=None,
|
| 24 |
+
embedding_model=None,
|
| 25 |
+
summarization_model=None,
|
| 26 |
+
tree_builder_type="cluster",
|
| 27 |
+
# New parameters for TreeRetrieverConfig and TreeBuilderConfig
|
| 28 |
+
# TreeRetrieverConfig arguments
|
| 29 |
+
tr_tokenizer=None,
|
| 30 |
+
tr_threshold=0.5,
|
| 31 |
+
tr_top_k=5,
|
| 32 |
+
tr_selection_mode="top_k",
|
| 33 |
+
tr_context_embedding_model="OpenAI",
|
| 34 |
+
tr_embedding_model=None,
|
| 35 |
+
tr_num_layers=None,
|
| 36 |
+
tr_start_layer=None,
|
| 37 |
+
# TreeBuilderConfig arguments
|
| 38 |
+
tb_tokenizer=None,
|
| 39 |
+
tb_max_tokens=100,
|
| 40 |
+
tb_num_layers=5,
|
| 41 |
+
tb_threshold=0.5,
|
| 42 |
+
tb_top_k=5,
|
| 43 |
+
tb_selection_mode="top_k",
|
| 44 |
+
tb_summarization_length=100,
|
| 45 |
+
tb_summarization_model=None,
|
| 46 |
+
tb_embedding_models=None,
|
| 47 |
+
tb_cluster_embedding_model="OpenAI",
|
| 48 |
+
):
|
| 49 |
+
# Validate tree_builder_type
|
| 50 |
+
if tree_builder_type not in supported_tree_builders:
|
| 51 |
+
raise ValueError(
|
| 52 |
+
f"tree_builder_type must be one of {list(supported_tree_builders.keys())}"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Validate qa_model
|
| 56 |
+
if qa_model is not None and not isinstance(qa_model, BaseQAModel):
|
| 57 |
+
raise ValueError("qa_model must be an instance of BaseQAModel")
|
| 58 |
+
|
| 59 |
+
if embedding_model is not None and not isinstance(
|
| 60 |
+
embedding_model, BaseEmbeddingModel
|
| 61 |
+
):
|
| 62 |
+
raise ValueError(
|
| 63 |
+
"embedding_model must be an instance of BaseEmbeddingModel"
|
| 64 |
+
)
|
| 65 |
+
elif embedding_model is not None:
|
| 66 |
+
if tb_embedding_models is not None:
|
| 67 |
+
raise ValueError(
|
| 68 |
+
"Only one of 'tb_embedding_models' or 'embedding_model' should be provided, not both."
|
| 69 |
+
)
|
| 70 |
+
tb_embedding_models = {"EMB": embedding_model}
|
| 71 |
+
tr_embedding_model = embedding_model
|
| 72 |
+
tb_cluster_embedding_model = "EMB"
|
| 73 |
+
tr_context_embedding_model = "EMB"
|
| 74 |
+
|
| 75 |
+
if summarization_model is not None and not isinstance(
|
| 76 |
+
summarization_model, BaseSummarizationModel
|
| 77 |
+
):
|
| 78 |
+
raise ValueError(
|
| 79 |
+
"summarization_model must be an instance of BaseSummarizationModel"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
elif summarization_model is not None:
|
| 83 |
+
if tb_summarization_model is not None:
|
| 84 |
+
raise ValueError(
|
| 85 |
+
"Only one of 'tb_summarization_model' or 'summarization_model' should be provided, not both."
|
| 86 |
+
)
|
| 87 |
+
tb_summarization_model = summarization_model
|
| 88 |
+
|
| 89 |
+
# Set TreeBuilderConfig
|
| 90 |
+
tree_builder_class, tree_builder_config_class = supported_tree_builders[
|
| 91 |
+
tree_builder_type
|
| 92 |
+
]
|
| 93 |
+
if tree_builder_config is None:
|
| 94 |
+
tree_builder_config = tree_builder_config_class(
|
| 95 |
+
tokenizer=tb_tokenizer,
|
| 96 |
+
max_tokens=tb_max_tokens,
|
| 97 |
+
num_layers=tb_num_layers,
|
| 98 |
+
threshold=tb_threshold,
|
| 99 |
+
top_k=tb_top_k,
|
| 100 |
+
selection_mode=tb_selection_mode,
|
| 101 |
+
summarization_length=tb_summarization_length,
|
| 102 |
+
summarization_model=tb_summarization_model,
|
| 103 |
+
embedding_models=tb_embedding_models,
|
| 104 |
+
cluster_embedding_model=tb_cluster_embedding_model,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
elif not isinstance(tree_builder_config, tree_builder_config_class):
|
| 108 |
+
raise ValueError(
|
| 109 |
+
f"tree_builder_config must be a direct instance of {tree_builder_config_class} for tree_builder_type '{tree_builder_type}'"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Set TreeRetrieverConfig
|
| 113 |
+
if tree_retriever_config is None:
|
| 114 |
+
tree_retriever_config = TreeRetrieverConfig(
|
| 115 |
+
tokenizer=tr_tokenizer,
|
| 116 |
+
threshold=tr_threshold,
|
| 117 |
+
top_k=tr_top_k,
|
| 118 |
+
selection_mode=tr_selection_mode,
|
| 119 |
+
context_embedding_model=tr_context_embedding_model,
|
| 120 |
+
embedding_model=tr_embedding_model,
|
| 121 |
+
num_layers=tr_num_layers,
|
| 122 |
+
start_layer=tr_start_layer,
|
| 123 |
+
)
|
| 124 |
+
elif not isinstance(tree_retriever_config, TreeRetrieverConfig):
|
| 125 |
+
raise ValueError(
|
| 126 |
+
"tree_retriever_config must be an instance of TreeRetrieverConfig"
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Assign the created configurations to the instance
|
| 130 |
+
self.tree_builder_config = tree_builder_config
|
| 131 |
+
self.tree_retriever_config = tree_retriever_config
|
| 132 |
+
self.qa_model = qa_model or GPT3TurboQAModel()
|
| 133 |
+
self.tree_builder_type = tree_builder_type
|
| 134 |
+
|
| 135 |
+
def log_config(self):
|
| 136 |
+
config_summary = """
|
| 137 |
+
RetrievalAugmentationConfig:
|
| 138 |
+
{tree_builder_config}
|
| 139 |
+
|
| 140 |
+
{tree_retriever_config}
|
| 141 |
+
|
| 142 |
+
QA Model: {qa_model}
|
| 143 |
+
Tree Builder Type: {tree_builder_type}
|
| 144 |
+
""".format(
|
| 145 |
+
tree_builder_config=self.tree_builder_config.log_config(),
|
| 146 |
+
tree_retriever_config=self.tree_retriever_config.log_config(),
|
| 147 |
+
qa_model=self.qa_model,
|
| 148 |
+
tree_builder_type=self.tree_builder_type,
|
| 149 |
+
)
|
| 150 |
+
return config_summary
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class RetrievalAugmentation:
|
| 154 |
+
"""
|
| 155 |
+
A Retrieval Augmentation class that combines the TreeBuilder and TreeRetriever classes.
|
| 156 |
+
Enables adding documents to the tree, retrieving information, and answering questions.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
def __init__(self, config=None, tree=None):
|
| 160 |
+
"""
|
| 161 |
+
Initializes a RetrievalAugmentation instance with the specified configuration.
|
| 162 |
+
Args:
|
| 163 |
+
config (RetrievalAugmentationConfig): The configuration for the RetrievalAugmentation instance.
|
| 164 |
+
tree: The tree instance or the path to a pickled tree file.
|
| 165 |
+
"""
|
| 166 |
+
if config is None:
|
| 167 |
+
config = RetrievalAugmentationConfig()
|
| 168 |
+
if not isinstance(config, RetrievalAugmentationConfig):
|
| 169 |
+
raise ValueError(
|
| 170 |
+
"config must be an instance of RetrievalAugmentationConfig"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Check if tree is a string (indicating a path to a pickled tree)
|
| 174 |
+
if isinstance(tree, str):
|
| 175 |
+
try:
|
| 176 |
+
with open(tree, "rb") as file:
|
| 177 |
+
self.tree = pickle.load(file)
|
| 178 |
+
if not isinstance(self.tree, Tree):
|
| 179 |
+
raise ValueError("The loaded object is not an instance of Tree")
|
| 180 |
+
except Exception as e:
|
| 181 |
+
raise ValueError(f"Failed to load tree from {tree}: {e}")
|
| 182 |
+
elif isinstance(tree, Tree) or tree is None:
|
| 183 |
+
self.tree = tree
|
| 184 |
+
else:
|
| 185 |
+
raise ValueError(
|
| 186 |
+
"tree must be an instance of Tree, a path to a pickled Tree, or None"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
tree_builder_class = supported_tree_builders[config.tree_builder_type][0]
|
| 190 |
+
self.tree_builder = tree_builder_class(config.tree_builder_config)
|
| 191 |
+
|
| 192 |
+
self.tree_retriever_config = config.tree_retriever_config
|
| 193 |
+
self.qa_model = config.qa_model
|
| 194 |
+
|
| 195 |
+
if self.tree is not None:
|
| 196 |
+
self.retriever = TreeRetriever(self.tree_retriever_config, self.tree)
|
| 197 |
+
else:
|
| 198 |
+
self.retriever = None
|
| 199 |
+
|
| 200 |
+
logging.info(
|
| 201 |
+
f"Successfully initialized RetrievalAugmentation with Config {config.log_config()}"
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
def add_documents(self, docs):
|
| 205 |
+
"""
|
| 206 |
+
Adds documents to the tree and creates a TreeRetriever instance.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
docs (str): The input text to add to the tree.
|
| 210 |
+
"""
|
| 211 |
+
if self.tree is not None:
|
| 212 |
+
user_input = input(
|
| 213 |
+
"Warning: Overwriting existing tree. Did you mean to call 'add_to_existing' instead? (y/n): "
|
| 214 |
+
)
|
| 215 |
+
if user_input.lower() == "y":
|
| 216 |
+
# self.add_to_existing(docs)
|
| 217 |
+
return
|
| 218 |
+
|
| 219 |
+
self.tree = self.tree_builder.build_from_text(text=docs)
|
| 220 |
+
self.retriever = TreeRetriever(self.tree_retriever_config, self.tree)
|
| 221 |
+
|
| 222 |
+
def retrieve(
|
| 223 |
+
self,
|
| 224 |
+
question,
|
| 225 |
+
start_layer: int = None,
|
| 226 |
+
num_layers: int = None,
|
| 227 |
+
top_k: int = 10,
|
| 228 |
+
max_tokens: int = 3500,
|
| 229 |
+
collapse_tree: bool = True,
|
| 230 |
+
return_layer_information: bool = True,
|
| 231 |
+
):
|
| 232 |
+
"""
|
| 233 |
+
Retrieves information and answers a question using the TreeRetriever instance.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
question (str): The question to answer.
|
| 237 |
+
start_layer (int): The layer to start from. Defaults to self.start_layer.
|
| 238 |
+
num_layers (int): The number of layers to traverse. Defaults to self.num_layers.
|
| 239 |
+
max_tokens (int): The maximum number of tokens. Defaults to 3500.
|
| 240 |
+
use_all_information (bool): Whether to retrieve information from all nodes. Defaults to False.
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
str: The context from which the answer can be found.
|
| 244 |
+
|
| 245 |
+
Raises:
|
| 246 |
+
ValueError: If the TreeRetriever instance has not been initialized.
|
| 247 |
+
"""
|
| 248 |
+
if self.retriever is None:
|
| 249 |
+
raise ValueError(
|
| 250 |
+
"The TreeRetriever instance has not been initialized. Call 'add_documents' first."
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
return self.retriever.retrieve(
|
| 254 |
+
question,
|
| 255 |
+
start_layer,
|
| 256 |
+
num_layers,
|
| 257 |
+
top_k,
|
| 258 |
+
max_tokens,
|
| 259 |
+
collapse_tree,
|
| 260 |
+
return_layer_information,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def answer_question(
|
| 264 |
+
self,
|
| 265 |
+
question,
|
| 266 |
+
top_k: int = 10,
|
| 267 |
+
start_layer: int = None,
|
| 268 |
+
num_layers: int = None,
|
| 269 |
+
max_tokens: int = 3500,
|
| 270 |
+
collapse_tree: bool = True,
|
| 271 |
+
return_layer_information: bool = False,
|
| 272 |
+
):
|
| 273 |
+
"""
|
| 274 |
+
Retrieves information and answers a question using the TreeRetriever instance.
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
question (str): The question to answer.
|
| 278 |
+
start_layer (int): The layer to start from. Defaults to self.start_layer.
|
| 279 |
+
num_layers (int): The number of layers to traverse. Defaults to self.num_layers.
|
| 280 |
+
max_tokens (int): The maximum number of tokens. Defaults to 3500.
|
| 281 |
+
use_all_information (bool): Whether to retrieve information from all nodes. Defaults to False.
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
str: The answer to the question.
|
| 285 |
+
|
| 286 |
+
Raises:
|
| 287 |
+
ValueError: If the TreeRetriever instance has not been initialized.
|
| 288 |
+
"""
|
| 289 |
+
# if return_layer_information:
|
| 290 |
+
context, layer_information = self.retrieve(
|
| 291 |
+
question, start_layer, num_layers, top_k, max_tokens, collapse_tree, True
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
answer = self.qa_model.answer_question(context, question)
|
| 295 |
+
|
| 296 |
+
if return_layer_information:
|
| 297 |
+
return answer, layer_information
|
| 298 |
+
|
| 299 |
+
return answer
|
| 300 |
+
|
| 301 |
+
def save(self, path):
|
| 302 |
+
if self.tree is None:
|
| 303 |
+
raise ValueError("There is no tree to save.")
|
| 304 |
+
with open(path, "wb") as file:
|
| 305 |
+
pickle.dump(self.tree, file)
|
| 306 |
+
logging.info(f"Tree successfully saved to {path}")
|
baselines/raptor/raptor/Retrievers.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BaseRetriever(ABC):
|
| 6 |
+
@abstractmethod
|
| 7 |
+
def retrieve(self, query: str) -> str:
|
| 8 |
+
pass
|
baselines/raptor/raptor/SummarizationModels.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
|
| 5 |
+
from openai import OpenAI
|
| 6 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseSummarizationModel(ABC):
|
| 12 |
+
@abstractmethod
|
| 13 |
+
def summarize(self, context, max_tokens=150):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GPT3TurboSummarizationModel(BaseSummarizationModel):
|
| 18 |
+
def __init__(self, model="gpt-3.5-turbo"):
|
| 19 |
+
|
| 20 |
+
self.model = model
|
| 21 |
+
|
| 22 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 23 |
+
def summarize(self, context, max_tokens=500, stop_sequence=None):
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
client = OpenAI()
|
| 27 |
+
|
| 28 |
+
response = client.chat.completions.create(
|
| 29 |
+
model=self.model,
|
| 30 |
+
messages=[
|
| 31 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 32 |
+
{
|
| 33 |
+
"role": "user",
|
| 34 |
+
"content": f"Write a summary of the following, including as many key details as possible: {context}:",
|
| 35 |
+
},
|
| 36 |
+
],
|
| 37 |
+
max_tokens=max_tokens,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
return response.choices[0].message.content
|
| 41 |
+
|
| 42 |
+
except Exception as e:
|
| 43 |
+
print(e)
|
| 44 |
+
return e
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class GPT3SummarizationModel(BaseSummarizationModel):
|
| 48 |
+
def __init__(self, model="text-davinci-003"):
|
| 49 |
+
|
| 50 |
+
self.model = model
|
| 51 |
+
|
| 52 |
+
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
|
| 53 |
+
def summarize(self, context, max_tokens=500, stop_sequence=None):
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
client = OpenAI()
|
| 57 |
+
|
| 58 |
+
response = client.chat.completions.create(
|
| 59 |
+
model=self.model,
|
| 60 |
+
messages=[
|
| 61 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 62 |
+
{
|
| 63 |
+
"role": "user",
|
| 64 |
+
"content": f"Write a summary of the following, including as many key details as possible: {context}:",
|
| 65 |
+
},
|
| 66 |
+
],
|
| 67 |
+
max_tokens=max_tokens,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
return response.choices[0].message.content
|
| 71 |
+
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(e)
|
| 74 |
+
return e
|
baselines/raptor/raptor/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# raptor/__init__.py
|
| 2 |
+
from .cluster_tree_builder import ClusterTreeBuilder, ClusterTreeConfig
|
| 3 |
+
from .EmbeddingModels import (BaseEmbeddingModel, OpenAIEmbeddingModel,
|
| 4 |
+
SBertEmbeddingModel)
|
| 5 |
+
from .FaissRetriever import FaissRetriever, FaissRetrieverConfig
|
| 6 |
+
from .QAModels import (BaseQAModel, GPT3QAModel, GPT3TurboQAModel, GPT4QAModel,
|
| 7 |
+
UnifiedQAModel)
|
| 8 |
+
from .RetrievalAugmentation import (RetrievalAugmentation,
|
| 9 |
+
RetrievalAugmentationConfig)
|
| 10 |
+
from .Retrievers import BaseRetriever
|
| 11 |
+
from .SummarizationModels import (BaseSummarizationModel,
|
| 12 |
+
GPT3SummarizationModel,
|
| 13 |
+
GPT3TurboSummarizationModel)
|
| 14 |
+
from .tree_builder import TreeBuilder, TreeBuilderConfig
|
| 15 |
+
from .tree_retriever import TreeRetriever, TreeRetrieverConfig
|
| 16 |
+
from .tree_structures import Node, Tree
|
baselines/raptor/raptor/cluster_tree_builder.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import pickle
|
| 3 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 4 |
+
from threading import Lock
|
| 5 |
+
from typing import Dict, List, Set
|
| 6 |
+
|
| 7 |
+
from .cluster_utils import ClusteringAlgorithm, RAPTOR_Clustering
|
| 8 |
+
from .tree_builder import TreeBuilder, TreeBuilderConfig
|
| 9 |
+
from .tree_structures import Node, Tree
|
| 10 |
+
from .utils import (distances_from_embeddings, get_children, get_embeddings,
|
| 11 |
+
get_node_list, get_text,
|
| 12 |
+
indices_of_nearest_neighbors_from_distances, split_text)
|
| 13 |
+
|
| 14 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ClusterTreeConfig(TreeBuilderConfig):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
reduction_dimension=10,
|
| 21 |
+
clustering_algorithm=RAPTOR_Clustering, # Default to RAPTOR clustering
|
| 22 |
+
clustering_params={}, # Pass additional params as a dict
|
| 23 |
+
*args,
|
| 24 |
+
**kwargs,
|
| 25 |
+
):
|
| 26 |
+
super().__init__(*args, **kwargs)
|
| 27 |
+
self.reduction_dimension = reduction_dimension
|
| 28 |
+
self.clustering_algorithm = clustering_algorithm
|
| 29 |
+
self.clustering_params = clustering_params
|
| 30 |
+
|
| 31 |
+
def log_config(self):
|
| 32 |
+
base_summary = super().log_config()
|
| 33 |
+
cluster_tree_summary = f"""
|
| 34 |
+
Reduction Dimension: {self.reduction_dimension}
|
| 35 |
+
Clustering Algorithm: {self.clustering_algorithm.__name__}
|
| 36 |
+
Clustering Parameters: {self.clustering_params}
|
| 37 |
+
"""
|
| 38 |
+
return base_summary + cluster_tree_summary
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ClusterTreeBuilder(TreeBuilder):
|
| 42 |
+
def __init__(self, config) -> None:
|
| 43 |
+
super().__init__(config)
|
| 44 |
+
|
| 45 |
+
if not isinstance(config, ClusterTreeConfig):
|
| 46 |
+
raise ValueError("config must be an instance of ClusterTreeConfig")
|
| 47 |
+
self.reduction_dimension = config.reduction_dimension
|
| 48 |
+
self.clustering_algorithm = config.clustering_algorithm
|
| 49 |
+
self.clustering_params = config.clustering_params
|
| 50 |
+
|
| 51 |
+
logging.info(
|
| 52 |
+
f"Successfully initialized ClusterTreeBuilder with Config {config.log_config()}"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def construct_tree(
|
| 56 |
+
self,
|
| 57 |
+
current_level_nodes: Dict[int, Node],
|
| 58 |
+
all_tree_nodes: Dict[int, Node],
|
| 59 |
+
layer_to_nodes: Dict[int, List[Node]],
|
| 60 |
+
use_multithreading: bool = False,
|
| 61 |
+
) -> Dict[int, Node]:
|
| 62 |
+
logging.info("Using Cluster TreeBuilder")
|
| 63 |
+
|
| 64 |
+
next_node_index = len(all_tree_nodes)
|
| 65 |
+
|
| 66 |
+
def process_cluster(
|
| 67 |
+
cluster, new_level_nodes, next_node_index, summarization_length, lock
|
| 68 |
+
):
|
| 69 |
+
node_texts = get_text(cluster)
|
| 70 |
+
|
| 71 |
+
summarized_text = self.summarize(
|
| 72 |
+
context=node_texts,
|
| 73 |
+
max_tokens=summarization_length,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
logging.info(
|
| 77 |
+
f"Node Texts Length: {len(self.tokenizer.encode(node_texts))}, Summarized Text Length: {len(self.tokenizer.encode(summarized_text))}"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
__, new_parent_node = self.create_node(
|
| 81 |
+
next_node_index, summarized_text, {node.index for node in cluster}
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
with lock:
|
| 85 |
+
new_level_nodes[next_node_index] = new_parent_node
|
| 86 |
+
|
| 87 |
+
for layer in range(self.num_layers):
|
| 88 |
+
|
| 89 |
+
new_level_nodes = {}
|
| 90 |
+
|
| 91 |
+
logging.info(f"Constructing Layer {layer}")
|
| 92 |
+
|
| 93 |
+
node_list_current_layer = get_node_list(current_level_nodes)
|
| 94 |
+
|
| 95 |
+
if len(node_list_current_layer) <= self.reduction_dimension + 1:
|
| 96 |
+
self.num_layers = layer
|
| 97 |
+
logging.info(
|
| 98 |
+
f"Stopping Layer construction: Cannot Create More Layers. Total Layers in tree: {layer}"
|
| 99 |
+
)
|
| 100 |
+
break
|
| 101 |
+
|
| 102 |
+
clusters = self.clustering_algorithm.perform_clustering(
|
| 103 |
+
node_list_current_layer,
|
| 104 |
+
self.cluster_embedding_model,
|
| 105 |
+
reduction_dimension=self.reduction_dimension,
|
| 106 |
+
**self.clustering_params,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
lock = Lock()
|
| 110 |
+
|
| 111 |
+
summarization_length = self.summarization_length
|
| 112 |
+
logging.info(f"Summarization Length: {summarization_length}")
|
| 113 |
+
|
| 114 |
+
if use_multithreading:
|
| 115 |
+
with ThreadPoolExecutor() as executor:
|
| 116 |
+
for cluster in clusters:
|
| 117 |
+
executor.submit(
|
| 118 |
+
process_cluster,
|
| 119 |
+
cluster,
|
| 120 |
+
new_level_nodes,
|
| 121 |
+
next_node_index,
|
| 122 |
+
summarization_length,
|
| 123 |
+
lock,
|
| 124 |
+
)
|
| 125 |
+
next_node_index += 1
|
| 126 |
+
executor.shutdown(wait=True)
|
| 127 |
+
|
| 128 |
+
else:
|
| 129 |
+
for cluster in clusters:
|
| 130 |
+
process_cluster(
|
| 131 |
+
cluster,
|
| 132 |
+
new_level_nodes,
|
| 133 |
+
next_node_index,
|
| 134 |
+
summarization_length,
|
| 135 |
+
lock,
|
| 136 |
+
)
|
| 137 |
+
next_node_index += 1
|
| 138 |
+
|
| 139 |
+
layer_to_nodes[layer + 1] = list(new_level_nodes.values())
|
| 140 |
+
current_level_nodes = new_level_nodes
|
| 141 |
+
all_tree_nodes.update(new_level_nodes)
|
| 142 |
+
|
| 143 |
+
tree = Tree(
|
| 144 |
+
all_tree_nodes,
|
| 145 |
+
layer_to_nodes[layer + 1],
|
| 146 |
+
layer_to_nodes[0],
|
| 147 |
+
layer + 1,
|
| 148 |
+
layer_to_nodes,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
return current_level_nodes
|
baselines/raptor/raptor/cluster_utils.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import tiktoken
|
| 8 |
+
import umap
|
| 9 |
+
from sklearn.mixture import GaussianMixture
|
| 10 |
+
|
| 11 |
+
# Initialize logging
|
| 12 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 13 |
+
|
| 14 |
+
from .tree_structures import Node
|
| 15 |
+
# Import necessary methods from other modules
|
| 16 |
+
from .utils import get_embeddings
|
| 17 |
+
|
| 18 |
+
# Set a random seed for reproducibility
|
| 19 |
+
RANDOM_SEED = 224
|
| 20 |
+
random.seed(RANDOM_SEED)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def global_cluster_embeddings(
|
| 24 |
+
embeddings: np.ndarray,
|
| 25 |
+
dim: int,
|
| 26 |
+
n_neighbors: Optional[int] = None,
|
| 27 |
+
metric: str = "cosine",
|
| 28 |
+
) -> np.ndarray:
|
| 29 |
+
if n_neighbors is None:
|
| 30 |
+
n_neighbors = int((len(embeddings) - 1) ** 0.5)
|
| 31 |
+
reduced_embeddings = umap.UMAP(
|
| 32 |
+
n_neighbors=n_neighbors, n_components=dim, metric=metric
|
| 33 |
+
).fit_transform(embeddings)
|
| 34 |
+
return reduced_embeddings
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def local_cluster_embeddings(
|
| 38 |
+
embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
|
| 39 |
+
) -> np.ndarray:
|
| 40 |
+
reduced_embeddings = umap.UMAP(
|
| 41 |
+
n_neighbors=num_neighbors, n_components=dim, metric=metric
|
| 42 |
+
).fit_transform(embeddings)
|
| 43 |
+
return reduced_embeddings
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_optimal_clusters(
|
| 47 |
+
embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED
|
| 48 |
+
) -> int:
|
| 49 |
+
max_clusters = min(max_clusters, len(embeddings))
|
| 50 |
+
n_clusters = np.arange(1, max_clusters)
|
| 51 |
+
bics = []
|
| 52 |
+
for n in n_clusters:
|
| 53 |
+
gm = GaussianMixture(n_components=n, random_state=random_state)
|
| 54 |
+
gm.fit(embeddings)
|
| 55 |
+
bics.append(gm.bic(embeddings))
|
| 56 |
+
optimal_clusters = n_clusters[np.argmin(bics)]
|
| 57 |
+
return optimal_clusters
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
|
| 61 |
+
n_clusters = get_optimal_clusters(embeddings)
|
| 62 |
+
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
|
| 63 |
+
gm.fit(embeddings)
|
| 64 |
+
probs = gm.predict_proba(embeddings)
|
| 65 |
+
labels = [np.where(prob > threshold)[0] for prob in probs]
|
| 66 |
+
return labels, n_clusters
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def perform_clustering(
|
| 70 |
+
embeddings: np.ndarray, dim: int, threshold: float, verbose: bool = False
|
| 71 |
+
) -> List[np.ndarray]:
|
| 72 |
+
reduced_embeddings_global = global_cluster_embeddings(embeddings, min(dim, len(embeddings) -2))
|
| 73 |
+
global_clusters, n_global_clusters = GMM_cluster(
|
| 74 |
+
reduced_embeddings_global, threshold
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if verbose:
|
| 78 |
+
logging.info(f"Global Clusters: {n_global_clusters}")
|
| 79 |
+
|
| 80 |
+
all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
|
| 81 |
+
total_clusters = 0
|
| 82 |
+
|
| 83 |
+
for i in range(n_global_clusters):
|
| 84 |
+
global_cluster_embeddings_ = embeddings[
|
| 85 |
+
np.array([i in gc for gc in global_clusters])
|
| 86 |
+
]
|
| 87 |
+
if verbose:
|
| 88 |
+
logging.info(
|
| 89 |
+
f"Nodes in Global Cluster {i}: {len(global_cluster_embeddings_)}"
|
| 90 |
+
)
|
| 91 |
+
if len(global_cluster_embeddings_) == 0:
|
| 92 |
+
continue
|
| 93 |
+
if len(global_cluster_embeddings_) <= dim + 1:
|
| 94 |
+
local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
|
| 95 |
+
n_local_clusters = 1
|
| 96 |
+
else:
|
| 97 |
+
reduced_embeddings_local = local_cluster_embeddings(
|
| 98 |
+
global_cluster_embeddings_, dim
|
| 99 |
+
)
|
| 100 |
+
local_clusters, n_local_clusters = GMM_cluster(
|
| 101 |
+
reduced_embeddings_local, threshold
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
if verbose:
|
| 105 |
+
logging.info(f"Local Clusters in Global Cluster {i}: {n_local_clusters}")
|
| 106 |
+
|
| 107 |
+
for j in range(n_local_clusters):
|
| 108 |
+
local_cluster_embeddings_ = global_cluster_embeddings_[
|
| 109 |
+
np.array([j in lc for lc in local_clusters])
|
| 110 |
+
]
|
| 111 |
+
indices = np.where(
|
| 112 |
+
(embeddings == local_cluster_embeddings_[:, None]).all(-1)
|
| 113 |
+
)[1]
|
| 114 |
+
for idx in indices:
|
| 115 |
+
all_local_clusters[idx] = np.append(
|
| 116 |
+
all_local_clusters[idx], j + total_clusters
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
total_clusters += n_local_clusters
|
| 120 |
+
|
| 121 |
+
if verbose:
|
| 122 |
+
logging.info(f"Total Clusters: {total_clusters}")
|
| 123 |
+
return all_local_clusters
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class ClusteringAlgorithm(ABC):
|
| 127 |
+
@abstractmethod
|
| 128 |
+
def perform_clustering(self, embeddings: np.ndarray, **kwargs) -> List[List[int]]:
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class RAPTOR_Clustering(ClusteringAlgorithm):
|
| 133 |
+
def perform_clustering(
|
| 134 |
+
nodes: List[Node],
|
| 135 |
+
embedding_model_name: str,
|
| 136 |
+
max_length_in_cluster: int = 3500,
|
| 137 |
+
tokenizer=tiktoken.get_encoding("cl100k_base"),
|
| 138 |
+
reduction_dimension: int = 10,
|
| 139 |
+
threshold: float = 0.1,
|
| 140 |
+
verbose: bool = False,
|
| 141 |
+
) -> List[List[Node]]:
|
| 142 |
+
# Get the embeddings from the nodes
|
| 143 |
+
embeddings = np.array([node.embeddings[embedding_model_name] for node in nodes])
|
| 144 |
+
|
| 145 |
+
# Perform the clustering
|
| 146 |
+
clusters = perform_clustering(
|
| 147 |
+
embeddings, dim=reduction_dimension, threshold=threshold
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Initialize an empty list to store the clusters of nodes
|
| 151 |
+
node_clusters = []
|
| 152 |
+
|
| 153 |
+
# Iterate over each unique label in the clusters
|
| 154 |
+
for label in np.unique(np.concatenate(clusters)):
|
| 155 |
+
# Get the indices of the nodes that belong to this cluster
|
| 156 |
+
indices = [i for i, cluster in enumerate(clusters) if label in cluster]
|
| 157 |
+
|
| 158 |
+
# Add the corresponding nodes to the node_clusters list
|
| 159 |
+
cluster_nodes = [nodes[i] for i in indices]
|
| 160 |
+
|
| 161 |
+
# Base case: if the cluster only has one node, do not attempt to recluster it
|
| 162 |
+
if len(cluster_nodes) == 1:
|
| 163 |
+
node_clusters.append(cluster_nodes)
|
| 164 |
+
continue
|
| 165 |
+
|
| 166 |
+
# Calculate the total length of the text in the nodes
|
| 167 |
+
total_length = sum(
|
| 168 |
+
[len(tokenizer.encode(node.text)) for node in cluster_nodes]
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# If the total length exceeds the maximum allowed length, recluster this cluster
|
| 172 |
+
if total_length > max_length_in_cluster:
|
| 173 |
+
if verbose:
|
| 174 |
+
logging.info(
|
| 175 |
+
f"reclustering cluster with {len(cluster_nodes)} nodes"
|
| 176 |
+
)
|
| 177 |
+
node_clusters.extend(
|
| 178 |
+
RAPTOR_Clustering.perform_clustering(
|
| 179 |
+
cluster_nodes, embedding_model_name, max_length_in_cluster
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
node_clusters.append(cluster_nodes)
|
| 184 |
+
|
| 185 |
+
return node_clusters
|
baselines/raptor/raptor/tree_builder.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from abc import abstractclassmethod
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 6 |
+
from threading import Lock
|
| 7 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 8 |
+
|
| 9 |
+
import openai
|
| 10 |
+
import tiktoken
|
| 11 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
| 12 |
+
|
| 13 |
+
from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel
|
| 14 |
+
from .SummarizationModels import (BaseSummarizationModel,
|
| 15 |
+
GPT3TurboSummarizationModel)
|
| 16 |
+
from .tree_structures import Node, Tree
|
| 17 |
+
from .utils import (distances_from_embeddings, get_children, get_embeddings,
|
| 18 |
+
get_node_list, get_text,
|
| 19 |
+
indices_of_nearest_neighbors_from_distances, split_text)
|
| 20 |
+
|
| 21 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TreeBuilderConfig:
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
tokenizer=None,
|
| 28 |
+
max_tokens=None,
|
| 29 |
+
num_layers=None,
|
| 30 |
+
threshold=None,
|
| 31 |
+
top_k=None,
|
| 32 |
+
selection_mode=None,
|
| 33 |
+
summarization_length=None,
|
| 34 |
+
summarization_model=None,
|
| 35 |
+
embedding_models=None,
|
| 36 |
+
cluster_embedding_model=None,
|
| 37 |
+
):
|
| 38 |
+
if tokenizer is None:
|
| 39 |
+
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 40 |
+
self.tokenizer = tokenizer
|
| 41 |
+
|
| 42 |
+
if max_tokens is None:
|
| 43 |
+
max_tokens = 100
|
| 44 |
+
if not isinstance(max_tokens, int) or max_tokens < 1:
|
| 45 |
+
raise ValueError("max_tokens must be an integer and at least 1")
|
| 46 |
+
self.max_tokens = max_tokens
|
| 47 |
+
|
| 48 |
+
if num_layers is None:
|
| 49 |
+
num_layers = 5
|
| 50 |
+
if not isinstance(num_layers, int) or num_layers < 1:
|
| 51 |
+
raise ValueError("num_layers must be an integer and at least 1")
|
| 52 |
+
self.num_layers = num_layers
|
| 53 |
+
|
| 54 |
+
if threshold is None:
|
| 55 |
+
threshold = 0.5
|
| 56 |
+
if not isinstance(threshold, (int, float)) or not (0 <= threshold <= 1):
|
| 57 |
+
raise ValueError("threshold must be a number between 0 and 1")
|
| 58 |
+
self.threshold = threshold
|
| 59 |
+
|
| 60 |
+
if top_k is None:
|
| 61 |
+
top_k = 5
|
| 62 |
+
if not isinstance(top_k, int) or top_k < 1:
|
| 63 |
+
raise ValueError("top_k must be an integer and at least 1")
|
| 64 |
+
self.top_k = top_k
|
| 65 |
+
|
| 66 |
+
if selection_mode is None:
|
| 67 |
+
selection_mode = "top_k"
|
| 68 |
+
if selection_mode not in ["top_k", "threshold"]:
|
| 69 |
+
raise ValueError("selection_mode must be either 'top_k' or 'threshold'")
|
| 70 |
+
self.selection_mode = selection_mode
|
| 71 |
+
|
| 72 |
+
if summarization_length is None:
|
| 73 |
+
summarization_length = 100
|
| 74 |
+
self.summarization_length = summarization_length
|
| 75 |
+
|
| 76 |
+
if summarization_model is None:
|
| 77 |
+
summarization_model = GPT3TurboSummarizationModel()
|
| 78 |
+
if not isinstance(summarization_model, BaseSummarizationModel):
|
| 79 |
+
raise ValueError(
|
| 80 |
+
"summarization_model must be an instance of BaseSummarizationModel"
|
| 81 |
+
)
|
| 82 |
+
self.summarization_model = summarization_model
|
| 83 |
+
|
| 84 |
+
if embedding_models is None:
|
| 85 |
+
embedding_models = {"OpenAI": OpenAIEmbeddingModel()}
|
| 86 |
+
if not isinstance(embedding_models, dict):
|
| 87 |
+
raise ValueError(
|
| 88 |
+
"embedding_models must be a dictionary of model_name: instance pairs"
|
| 89 |
+
)
|
| 90 |
+
for model in embedding_models.values():
|
| 91 |
+
if not isinstance(model, BaseEmbeddingModel):
|
| 92 |
+
raise ValueError(
|
| 93 |
+
"All embedding models must be an instance of BaseEmbeddingModel"
|
| 94 |
+
)
|
| 95 |
+
self.embedding_models = embedding_models
|
| 96 |
+
|
| 97 |
+
if cluster_embedding_model is None:
|
| 98 |
+
cluster_embedding_model = "OpenAI"
|
| 99 |
+
if cluster_embedding_model not in self.embedding_models:
|
| 100 |
+
raise ValueError(
|
| 101 |
+
"cluster_embedding_model must be a key in the embedding_models dictionary"
|
| 102 |
+
)
|
| 103 |
+
self.cluster_embedding_model = cluster_embedding_model
|
| 104 |
+
|
| 105 |
+
def log_config(self):
|
| 106 |
+
config_log = """
|
| 107 |
+
TreeBuilderConfig:
|
| 108 |
+
Tokenizer: {tokenizer}
|
| 109 |
+
Max Tokens: {max_tokens}
|
| 110 |
+
Num Layers: {num_layers}
|
| 111 |
+
Threshold: {threshold}
|
| 112 |
+
Top K: {top_k}
|
| 113 |
+
Selection Mode: {selection_mode}
|
| 114 |
+
Summarization Length: {summarization_length}
|
| 115 |
+
Summarization Model: {summarization_model}
|
| 116 |
+
Embedding Models: {embedding_models}
|
| 117 |
+
Cluster Embedding Model: {cluster_embedding_model}
|
| 118 |
+
""".format(
|
| 119 |
+
tokenizer=self.tokenizer,
|
| 120 |
+
max_tokens=self.max_tokens,
|
| 121 |
+
num_layers=self.num_layers,
|
| 122 |
+
threshold=self.threshold,
|
| 123 |
+
top_k=self.top_k,
|
| 124 |
+
selection_mode=self.selection_mode,
|
| 125 |
+
summarization_length=self.summarization_length,
|
| 126 |
+
summarization_model=self.summarization_model,
|
| 127 |
+
embedding_models=self.embedding_models,
|
| 128 |
+
cluster_embedding_model=self.cluster_embedding_model,
|
| 129 |
+
)
|
| 130 |
+
return config_log
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class TreeBuilder:
|
| 134 |
+
"""
|
| 135 |
+
The TreeBuilder class is responsible for building a hierarchical text abstraction
|
| 136 |
+
structure, known as a "tree," using summarization models and
|
| 137 |
+
embedding models.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __init__(self, config) -> None:
|
| 141 |
+
"""Initializes the tokenizer, maximum tokens, number of layers, top-k value, threshold, and selection mode."""
|
| 142 |
+
|
| 143 |
+
self.tokenizer = config.tokenizer
|
| 144 |
+
self.max_tokens = config.max_tokens
|
| 145 |
+
self.num_layers = config.num_layers
|
| 146 |
+
self.top_k = config.top_k
|
| 147 |
+
self.threshold = config.threshold
|
| 148 |
+
self.selection_mode = config.selection_mode
|
| 149 |
+
self.summarization_length = config.summarization_length
|
| 150 |
+
self.summarization_model = config.summarization_model
|
| 151 |
+
self.embedding_models = config.embedding_models
|
| 152 |
+
self.cluster_embedding_model = config.cluster_embedding_model
|
| 153 |
+
|
| 154 |
+
logging.info(
|
| 155 |
+
f"Successfully initialized TreeBuilder with Config {config.log_config()}"
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
def create_node(
|
| 159 |
+
self, index: int, text: str, children_indices: Optional[Set[int]] = None
|
| 160 |
+
) -> Tuple[int, Node]:
|
| 161 |
+
"""Creates a new node with the given index, text, and (optionally) children indices.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
index (int): The index of the new node.
|
| 165 |
+
text (str): The text associated with the new node.
|
| 166 |
+
children_indices (Optional[Set[int]]): A set of indices representing the children of the new node.
|
| 167 |
+
If not provided, an empty set will be used.
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
Tuple[int, Node]: A tuple containing the index and the newly created node.
|
| 171 |
+
"""
|
| 172 |
+
if children_indices is None:
|
| 173 |
+
children_indices = set()
|
| 174 |
+
|
| 175 |
+
embeddings = {
|
| 176 |
+
model_name: model.create_embedding(text)
|
| 177 |
+
for model_name, model in self.embedding_models.items()
|
| 178 |
+
}
|
| 179 |
+
return (index, Node(text, index, children_indices, embeddings))
|
| 180 |
+
|
| 181 |
+
def create_embedding(self, text) -> List[float]:
|
| 182 |
+
"""
|
| 183 |
+
Generates embeddings for the given text using the specified embedding model.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
text (str): The text for which to generate embeddings.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
List[float]: The generated embeddings.
|
| 190 |
+
"""
|
| 191 |
+
return self.embedding_models[self.cluster_embedding_model].create_embedding(
|
| 192 |
+
text
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def summarize(self, context, max_tokens=150) -> str:
|
| 196 |
+
"""
|
| 197 |
+
Generates a summary of the input context using the specified summarization model.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
context (str, optional): The context to summarize.
|
| 201 |
+
max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.o
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
str: The generated summary.
|
| 205 |
+
"""
|
| 206 |
+
return self.summarization_model.summarize(context, max_tokens)
|
| 207 |
+
|
| 208 |
+
def get_relevant_nodes(self, current_node, list_nodes) -> List[Node]:
|
| 209 |
+
"""
|
| 210 |
+
Retrieves the top-k most relevant nodes to the current node from the list of nodes
|
| 211 |
+
based on cosine distance in the embedding space.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
current_node (Node): The current node.
|
| 215 |
+
list_nodes (List[Node]): The list of nodes.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
List[Node]: The top-k most relevant nodes.
|
| 219 |
+
"""
|
| 220 |
+
embeddings = get_embeddings(list_nodes, self.cluster_embedding_model)
|
| 221 |
+
distances = distances_from_embeddings(
|
| 222 |
+
current_node.embeddings[self.cluster_embedding_model], embeddings
|
| 223 |
+
)
|
| 224 |
+
indices = indices_of_nearest_neighbors_from_distances(distances)
|
| 225 |
+
|
| 226 |
+
if self.selection_mode == "threshold":
|
| 227 |
+
best_indices = [
|
| 228 |
+
index for index in indices if distances[index] > self.threshold
|
| 229 |
+
]
|
| 230 |
+
|
| 231 |
+
elif self.selection_mode == "top_k":
|
| 232 |
+
best_indices = indices[: self.top_k]
|
| 233 |
+
|
| 234 |
+
nodes_to_add = [list_nodes[idx] for idx in best_indices]
|
| 235 |
+
|
| 236 |
+
return nodes_to_add
|
| 237 |
+
|
| 238 |
+
def multithreaded_create_leaf_nodes(self, chunks: List[str]) -> Dict[int, Node]:
|
| 239 |
+
"""Creates leaf nodes using multithreading from the given list of text chunks.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
chunks (List[str]): A list of text chunks to be turned into leaf nodes.
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
Dict[int, Node]: A dictionary mapping node indices to the corresponding leaf nodes.
|
| 246 |
+
"""
|
| 247 |
+
with ThreadPoolExecutor() as executor:
|
| 248 |
+
future_nodes = {
|
| 249 |
+
executor.submit(self.create_node, index, text): (index, text)
|
| 250 |
+
for index, text in enumerate(chunks)
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
leaf_nodes = {}
|
| 254 |
+
for future in as_completed(future_nodes):
|
| 255 |
+
index, node = future.result()
|
| 256 |
+
leaf_nodes[index] = node
|
| 257 |
+
|
| 258 |
+
return leaf_nodes
|
| 259 |
+
|
| 260 |
+
def build_from_text(self, text: str, use_multithreading: bool = True) -> Tree:
|
| 261 |
+
"""Builds a golden tree from the input text, optionally using multithreading.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
text (str): The input text.
|
| 265 |
+
use_multithreading (bool, optional): Whether to use multithreading when creating leaf nodes.
|
| 266 |
+
Default: True.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Tree: The golden tree structure.
|
| 270 |
+
"""
|
| 271 |
+
chunks = split_text(text, self.tokenizer, self.max_tokens)
|
| 272 |
+
|
| 273 |
+
logging.info("Creating Leaf Nodes")
|
| 274 |
+
|
| 275 |
+
if use_multithreading:
|
| 276 |
+
leaf_nodes = self.multithreaded_create_leaf_nodes(chunks)
|
| 277 |
+
else:
|
| 278 |
+
leaf_nodes = {}
|
| 279 |
+
for index, text in enumerate(chunks):
|
| 280 |
+
__, node = self.create_node(index, text)
|
| 281 |
+
leaf_nodes[index] = node
|
| 282 |
+
|
| 283 |
+
layer_to_nodes = {0: list(leaf_nodes.values())}
|
| 284 |
+
|
| 285 |
+
logging.info(f"Created {len(leaf_nodes)} Leaf Embeddings")
|
| 286 |
+
|
| 287 |
+
logging.info("Building All Nodes")
|
| 288 |
+
|
| 289 |
+
all_nodes = copy.deepcopy(leaf_nodes)
|
| 290 |
+
|
| 291 |
+
root_nodes = self.construct_tree(all_nodes, all_nodes, layer_to_nodes)
|
| 292 |
+
|
| 293 |
+
tree = Tree(all_nodes, root_nodes, leaf_nodes, self.num_layers, layer_to_nodes)
|
| 294 |
+
|
| 295 |
+
return tree
|
| 296 |
+
|
| 297 |
+
@abstractclassmethod
|
| 298 |
+
def construct_tree(
|
| 299 |
+
self,
|
| 300 |
+
current_level_nodes: Dict[int, Node],
|
| 301 |
+
all_tree_nodes: Dict[int, Node],
|
| 302 |
+
layer_to_nodes: Dict[int, List[Node]],
|
| 303 |
+
use_multithreading: bool = True,
|
| 304 |
+
) -> Dict[int, Node]:
|
| 305 |
+
"""
|
| 306 |
+
Constructs the hierarchical tree structure layer by layer by iteratively summarizing groups
|
| 307 |
+
of relevant nodes and updating the current_level_nodes and all_tree_nodes dictionaries at each step.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
current_level_nodes (Dict[int, Node]): The current set of nodes.
|
| 311 |
+
all_tree_nodes (Dict[int, Node]): The dictionary of all nodes.
|
| 312 |
+
use_multithreading (bool): Whether to use multithreading to speed up the process.
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
Dict[int, Node]: The final set of root nodes.
|
| 316 |
+
"""
|
| 317 |
+
pass
|
| 318 |
+
|
| 319 |
+
# logging.info("Using Transformer-like TreeBuilder")
|
| 320 |
+
|
| 321 |
+
# def process_node(idx, current_level_nodes, new_level_nodes, all_tree_nodes, next_node_index, lock):
|
| 322 |
+
# relevant_nodes_chunk = self.get_relevant_nodes(
|
| 323 |
+
# current_level_nodes[idx], current_level_nodes
|
| 324 |
+
# )
|
| 325 |
+
|
| 326 |
+
# node_texts = get_text(relevant_nodes_chunk)
|
| 327 |
+
|
| 328 |
+
# summarized_text = self.summarize(
|
| 329 |
+
# context=node_texts,
|
| 330 |
+
# max_tokens=self.summarization_length,
|
| 331 |
+
# )
|
| 332 |
+
|
| 333 |
+
# logging.info(
|
| 334 |
+
# f"Node Texts Length: {len(self.tokenizer.encode(node_texts))}, Summarized Text Length: {len(self.tokenizer.encode(summarized_text))}"
|
| 335 |
+
# )
|
| 336 |
+
|
| 337 |
+
# next_node_index, new_parent_node = self.create_node(
|
| 338 |
+
# next_node_index,
|
| 339 |
+
# summarized_text,
|
| 340 |
+
# {node.index for node in relevant_nodes_chunk}
|
| 341 |
+
# )
|
| 342 |
+
|
| 343 |
+
# with lock:
|
| 344 |
+
# new_level_nodes[next_node_index] = new_parent_node
|
| 345 |
+
|
| 346 |
+
# for layer in range(self.num_layers):
|
| 347 |
+
# logging.info(f"Constructing Layer {layer}: ")
|
| 348 |
+
|
| 349 |
+
# node_list_current_layer = get_node_list(current_level_nodes)
|
| 350 |
+
# next_node_index = len(all_tree_nodes)
|
| 351 |
+
|
| 352 |
+
# new_level_nodes = {}
|
| 353 |
+
# lock = Lock()
|
| 354 |
+
|
| 355 |
+
# if use_multithreading:
|
| 356 |
+
# with ThreadPoolExecutor() as executor:
|
| 357 |
+
# for idx in range(0, len(node_list_current_layer)):
|
| 358 |
+
# executor.submit(process_node, idx, node_list_current_layer, new_level_nodes, all_tree_nodes, next_node_index, lock)
|
| 359 |
+
# next_node_index += 1
|
| 360 |
+
# executor.shutdown(wait=True)
|
| 361 |
+
# else:
|
| 362 |
+
# for idx in range(0, len(node_list_current_layer)):
|
| 363 |
+
# process_node(idx, node_list_current_layer, new_level_nodes, all_tree_nodes, next_node_index, lock)
|
| 364 |
+
|
| 365 |
+
# layer_to_nodes[layer + 1] = list(new_level_nodes.values())
|
| 366 |
+
# current_level_nodes = new_level_nodes
|
| 367 |
+
# all_tree_nodes.update(new_level_nodes)
|
| 368 |
+
|
| 369 |
+
# return new_level_nodes
|
baselines/raptor/raptor/tree_retriever.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Dict, List, Set
|
| 4 |
+
|
| 5 |
+
import tiktoken
|
| 6 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
| 7 |
+
|
| 8 |
+
from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel
|
| 9 |
+
from .Retrievers import BaseRetriever
|
| 10 |
+
from .tree_structures import Node, Tree
|
| 11 |
+
from .utils import (distances_from_embeddings, get_children, get_embeddings,
|
| 12 |
+
get_node_list, get_text,
|
| 13 |
+
indices_of_nearest_neighbors_from_distances,
|
| 14 |
+
reverse_mapping)
|
| 15 |
+
|
| 16 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TreeRetrieverConfig:
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
tokenizer=None,
|
| 23 |
+
threshold=None,
|
| 24 |
+
top_k=None,
|
| 25 |
+
selection_mode=None,
|
| 26 |
+
context_embedding_model=None,
|
| 27 |
+
embedding_model=None,
|
| 28 |
+
num_layers=None,
|
| 29 |
+
start_layer=None,
|
| 30 |
+
):
|
| 31 |
+
if tokenizer is None:
|
| 32 |
+
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 33 |
+
self.tokenizer = tokenizer
|
| 34 |
+
|
| 35 |
+
if threshold is None:
|
| 36 |
+
threshold = 0.5
|
| 37 |
+
if not isinstance(threshold, float) or not (0 <= threshold <= 1):
|
| 38 |
+
raise ValueError("threshold must be a float between 0 and 1")
|
| 39 |
+
self.threshold = threshold
|
| 40 |
+
|
| 41 |
+
if top_k is None:
|
| 42 |
+
top_k = 5
|
| 43 |
+
if not isinstance(top_k, int) or top_k < 1:
|
| 44 |
+
raise ValueError("top_k must be an integer and at least 1")
|
| 45 |
+
self.top_k = top_k
|
| 46 |
+
|
| 47 |
+
if selection_mode is None:
|
| 48 |
+
selection_mode = "top_k"
|
| 49 |
+
if not isinstance(selection_mode, str) or selection_mode not in [
|
| 50 |
+
"top_k",
|
| 51 |
+
"threshold",
|
| 52 |
+
]:
|
| 53 |
+
raise ValueError(
|
| 54 |
+
"selection_mode must be a string and either 'top_k' or 'threshold'"
|
| 55 |
+
)
|
| 56 |
+
self.selection_mode = selection_mode
|
| 57 |
+
|
| 58 |
+
if context_embedding_model is None:
|
| 59 |
+
context_embedding_model = "OpenAI"
|
| 60 |
+
if not isinstance(context_embedding_model, str):
|
| 61 |
+
raise ValueError("context_embedding_model must be a string")
|
| 62 |
+
self.context_embedding_model = context_embedding_model
|
| 63 |
+
|
| 64 |
+
if embedding_model is None:
|
| 65 |
+
embedding_model = OpenAIEmbeddingModel()
|
| 66 |
+
if not isinstance(embedding_model, BaseEmbeddingModel):
|
| 67 |
+
raise ValueError(
|
| 68 |
+
"embedding_model must be an instance of BaseEmbeddingModel"
|
| 69 |
+
)
|
| 70 |
+
self.embedding_model = embedding_model
|
| 71 |
+
|
| 72 |
+
if num_layers is not None:
|
| 73 |
+
if not isinstance(num_layers, int) or num_layers < 0:
|
| 74 |
+
raise ValueError("num_layers must be an integer and at least 0")
|
| 75 |
+
self.num_layers = num_layers
|
| 76 |
+
|
| 77 |
+
if start_layer is not None:
|
| 78 |
+
if not isinstance(start_layer, int) or start_layer < 0:
|
| 79 |
+
raise ValueError("start_layer must be an integer and at least 0")
|
| 80 |
+
self.start_layer = start_layer
|
| 81 |
+
|
| 82 |
+
def log_config(self):
|
| 83 |
+
config_log = """
|
| 84 |
+
TreeRetrieverConfig:
|
| 85 |
+
Tokenizer: {tokenizer}
|
| 86 |
+
Threshold: {threshold}
|
| 87 |
+
Top K: {top_k}
|
| 88 |
+
Selection Mode: {selection_mode}
|
| 89 |
+
Context Embedding Model: {context_embedding_model}
|
| 90 |
+
Embedding Model: {embedding_model}
|
| 91 |
+
Num Layers: {num_layers}
|
| 92 |
+
Start Layer: {start_layer}
|
| 93 |
+
""".format(
|
| 94 |
+
tokenizer=self.tokenizer,
|
| 95 |
+
threshold=self.threshold,
|
| 96 |
+
top_k=self.top_k,
|
| 97 |
+
selection_mode=self.selection_mode,
|
| 98 |
+
context_embedding_model=self.context_embedding_model,
|
| 99 |
+
embedding_model=self.embedding_model,
|
| 100 |
+
num_layers=self.num_layers,
|
| 101 |
+
start_layer=self.start_layer,
|
| 102 |
+
)
|
| 103 |
+
return config_log
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class TreeRetriever(BaseRetriever):
|
| 107 |
+
|
| 108 |
+
def __init__(self, config, tree) -> None:
|
| 109 |
+
if not isinstance(tree, Tree):
|
| 110 |
+
raise ValueError("tree must be an instance of Tree")
|
| 111 |
+
|
| 112 |
+
if config.num_layers is not None and config.num_layers > tree.num_layers + 1:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
"num_layers in config must be less than or equal to tree.num_layers + 1"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if config.start_layer is not None and config.start_layer > tree.num_layers:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
"start_layer in config must be less than or equal to tree.num_layers"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
self.tree = tree
|
| 123 |
+
self.num_layers = (
|
| 124 |
+
config.num_layers if config.num_layers is not None else tree.num_layers + 1
|
| 125 |
+
)
|
| 126 |
+
self.start_layer = (
|
| 127 |
+
config.start_layer if config.start_layer is not None else tree.num_layers
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if self.num_layers > self.start_layer + 1:
|
| 131 |
+
raise ValueError("num_layers must be less than or equal to start_layer + 1")
|
| 132 |
+
|
| 133 |
+
self.tokenizer = config.tokenizer
|
| 134 |
+
self.top_k = config.top_k
|
| 135 |
+
self.threshold = config.threshold
|
| 136 |
+
self.selection_mode = config.selection_mode
|
| 137 |
+
self.embedding_model = config.embedding_model
|
| 138 |
+
self.context_embedding_model = config.context_embedding_model
|
| 139 |
+
|
| 140 |
+
self.tree_node_index_to_layer = reverse_mapping(self.tree.layer_to_nodes)
|
| 141 |
+
|
| 142 |
+
logging.info(
|
| 143 |
+
f"Successfully initialized TreeRetriever with Config {config.log_config()}"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def create_embedding(self, text: str) -> List[float]:
|
| 147 |
+
"""
|
| 148 |
+
Generates embeddings for the given text using the specified embedding model.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
text (str): The text for which to generate embeddings.
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
List[float]: The generated embeddings.
|
| 155 |
+
"""
|
| 156 |
+
return self.embedding_model.create_embedding(text)
|
| 157 |
+
|
| 158 |
+
def retrieve_information_collapse_tree(self, query: str, top_k: int, max_tokens: int) -> str:
|
| 159 |
+
"""
|
| 160 |
+
Retrieves the most relevant information from the tree based on the query.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
query (str): The query text.
|
| 164 |
+
max_tokens (int): The maximum number of tokens.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
str: The context created using the most relevant nodes.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
query_embedding = self.create_embedding(query)
|
| 171 |
+
|
| 172 |
+
selected_nodes = []
|
| 173 |
+
|
| 174 |
+
node_list = get_node_list(self.tree.all_nodes)
|
| 175 |
+
|
| 176 |
+
embeddings = get_embeddings(node_list, self.context_embedding_model)
|
| 177 |
+
|
| 178 |
+
distances = distances_from_embeddings(query_embedding, embeddings)
|
| 179 |
+
|
| 180 |
+
indices = indices_of_nearest_neighbors_from_distances(distances)
|
| 181 |
+
|
| 182 |
+
total_tokens = 0
|
| 183 |
+
for idx in indices[:top_k]:
|
| 184 |
+
|
| 185 |
+
node = node_list[idx]
|
| 186 |
+
node_tokens = len(self.tokenizer.encode(node.text))
|
| 187 |
+
|
| 188 |
+
if total_tokens + node_tokens > max_tokens:
|
| 189 |
+
break
|
| 190 |
+
|
| 191 |
+
selected_nodes.append(node)
|
| 192 |
+
total_tokens += node_tokens
|
| 193 |
+
|
| 194 |
+
context = get_text(selected_nodes)
|
| 195 |
+
return selected_nodes, context
|
| 196 |
+
|
| 197 |
+
def retrieve_information(
|
| 198 |
+
self, current_nodes: List[Node], query: str, num_layers: int
|
| 199 |
+
) -> str:
|
| 200 |
+
"""
|
| 201 |
+
Retrieves the most relevant information from the tree based on the query.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
current_nodes (List[Node]): A List of the current nodes.
|
| 205 |
+
query (str): The query text.
|
| 206 |
+
num_layers (int): The number of layers to traverse.
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
str: The context created using the most relevant nodes.
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
query_embedding = self.create_embedding(query)
|
| 213 |
+
|
| 214 |
+
selected_nodes = []
|
| 215 |
+
|
| 216 |
+
node_list = current_nodes
|
| 217 |
+
|
| 218 |
+
for layer in range(num_layers):
|
| 219 |
+
|
| 220 |
+
embeddings = get_embeddings(node_list, self.context_embedding_model)
|
| 221 |
+
|
| 222 |
+
distances = distances_from_embeddings(query_embedding, embeddings)
|
| 223 |
+
|
| 224 |
+
indices = indices_of_nearest_neighbors_from_distances(distances)
|
| 225 |
+
|
| 226 |
+
if self.selection_mode == "threshold":
|
| 227 |
+
best_indices = [
|
| 228 |
+
index for index in indices if distances[index] > self.threshold
|
| 229 |
+
]
|
| 230 |
+
|
| 231 |
+
elif self.selection_mode == "top_k":
|
| 232 |
+
best_indices = indices[: self.top_k]
|
| 233 |
+
|
| 234 |
+
nodes_to_add = [node_list[idx] for idx in best_indices]
|
| 235 |
+
|
| 236 |
+
selected_nodes.extend(nodes_to_add)
|
| 237 |
+
|
| 238 |
+
if layer != num_layers - 1:
|
| 239 |
+
|
| 240 |
+
child_nodes = []
|
| 241 |
+
|
| 242 |
+
for index in best_indices:
|
| 243 |
+
child_nodes.extend(node_list[index].children)
|
| 244 |
+
|
| 245 |
+
# take the unique values
|
| 246 |
+
child_nodes = list(dict.fromkeys(child_nodes))
|
| 247 |
+
node_list = [self.tree.all_nodes[i] for i in child_nodes]
|
| 248 |
+
|
| 249 |
+
context = get_text(selected_nodes)
|
| 250 |
+
return selected_nodes, context
|
| 251 |
+
|
| 252 |
+
def retrieve(
|
| 253 |
+
self,
|
| 254 |
+
query: str,
|
| 255 |
+
start_layer: int = None,
|
| 256 |
+
num_layers: int = None,
|
| 257 |
+
top_k: int = 10,
|
| 258 |
+
max_tokens: int = 3500,
|
| 259 |
+
collapse_tree: bool = True,
|
| 260 |
+
return_layer_information: bool = False,
|
| 261 |
+
) -> str:
|
| 262 |
+
"""
|
| 263 |
+
Queries the tree and returns the most relevant information.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
query (str): The query text.
|
| 267 |
+
start_layer (int): The layer to start from. Defaults to self.start_layer.
|
| 268 |
+
num_layers (int): The number of layers to traverse. Defaults to self.num_layers.
|
| 269 |
+
max_tokens (int): The maximum number of tokens. Defaults to 3500.
|
| 270 |
+
collapse_tree (bool): Whether to retrieve information from all nodes. Defaults to False.
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
str: The result of the query.
|
| 274 |
+
"""
|
| 275 |
+
|
| 276 |
+
if not isinstance(query, str):
|
| 277 |
+
raise ValueError("query must be a string")
|
| 278 |
+
|
| 279 |
+
if not isinstance(max_tokens, int) or max_tokens < 1:
|
| 280 |
+
raise ValueError("max_tokens must be an integer and at least 1")
|
| 281 |
+
|
| 282 |
+
if not isinstance(collapse_tree, bool):
|
| 283 |
+
raise ValueError("collapse_tree must be a boolean")
|
| 284 |
+
|
| 285 |
+
# Set defaults
|
| 286 |
+
start_layer = self.start_layer if start_layer is None else start_layer
|
| 287 |
+
num_layers = self.num_layers if num_layers is None else num_layers
|
| 288 |
+
|
| 289 |
+
if not isinstance(start_layer, int) or not (
|
| 290 |
+
0 <= start_layer <= self.tree.num_layers
|
| 291 |
+
):
|
| 292 |
+
raise ValueError(
|
| 293 |
+
"start_layer must be an integer between 0 and tree.num_layers"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
if not isinstance(num_layers, int) or num_layers < 1:
|
| 297 |
+
raise ValueError("num_layers must be an integer and at least 1")
|
| 298 |
+
|
| 299 |
+
if num_layers > (start_layer + 1):
|
| 300 |
+
raise ValueError("num_layers must be less than or equal to start_layer + 1")
|
| 301 |
+
|
| 302 |
+
if collapse_tree:
|
| 303 |
+
logging.info(f"Using collapsed_tree")
|
| 304 |
+
selected_nodes, context = self.retrieve_information_collapse_tree(
|
| 305 |
+
query, top_k, max_tokens
|
| 306 |
+
)
|
| 307 |
+
else:
|
| 308 |
+
layer_nodes = self.tree.layer_to_nodes[start_layer]
|
| 309 |
+
selected_nodes, context = self.retrieve_information(
|
| 310 |
+
layer_nodes, query, num_layers
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
if return_layer_information:
|
| 314 |
+
|
| 315 |
+
layer_information = []
|
| 316 |
+
|
| 317 |
+
for node in selected_nodes:
|
| 318 |
+
layer_information.append(
|
| 319 |
+
{
|
| 320 |
+
"node_index": node.index,
|
| 321 |
+
"layer_number": self.tree_node_index_to_layer[node.index],
|
| 322 |
+
}
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
return context, layer_information
|
| 326 |
+
|
| 327 |
+
return context
|
baselines/raptor/raptor/tree_structures.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Set
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Node:
|
| 5 |
+
"""
|
| 6 |
+
Represents a node in the hierarchical tree structure.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
def __init__(self, text: str, index: int, children: Set[int], embeddings) -> None:
|
| 10 |
+
self.text = text
|
| 11 |
+
self.index = index
|
| 12 |
+
self.children = children
|
| 13 |
+
self.embeddings = embeddings
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Tree:
|
| 17 |
+
"""
|
| 18 |
+
Represents the entire hierarchical tree structure.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self, all_nodes, root_nodes, leaf_nodes, num_layers, layer_to_nodes
|
| 23 |
+
) -> None:
|
| 24 |
+
self.all_nodes = all_nodes
|
| 25 |
+
self.root_nodes = root_nodes
|
| 26 |
+
self.leaf_nodes = leaf_nodes
|
| 27 |
+
self.num_layers = num_layers
|
| 28 |
+
self.layer_to_nodes = layer_to_nodes
|
baselines/raptor/raptor/utils.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import re
|
| 3 |
+
from typing import Dict, List, Set
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import tiktoken
|
| 7 |
+
from scipy import spatial
|
| 8 |
+
|
| 9 |
+
from .tree_structures import Node
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def reverse_mapping(layer_to_nodes: Dict[int, List[Node]]) -> Dict[Node, int]:
|
| 15 |
+
node_to_layer = {}
|
| 16 |
+
for layer, nodes in layer_to_nodes.items():
|
| 17 |
+
for node in nodes:
|
| 18 |
+
node_to_layer[node.index] = layer
|
| 19 |
+
return node_to_layer
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def split_text(
|
| 23 |
+
text: str, tokenizer: tiktoken.get_encoding("cl100k_base"), max_tokens: int, overlap: int = 0
|
| 24 |
+
):
|
| 25 |
+
"""
|
| 26 |
+
Splits the input text into smaller chunks based on the tokenizer and maximum allowed tokens.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
text (str): The text to be split.
|
| 30 |
+
tokenizer (CustomTokenizer): The tokenizer to be used for splitting the text.
|
| 31 |
+
max_tokens (int): The maximum allowed tokens.
|
| 32 |
+
overlap (int, optional): The number of overlapping tokens between chunks. Defaults to 0.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
List[str]: A list of text chunks.
|
| 36 |
+
"""
|
| 37 |
+
# Split the text into sentences using multiple delimiters
|
| 38 |
+
delimiters = [".", "!", "?", "\n"]
|
| 39 |
+
regex_pattern = "|".join(map(re.escape, delimiters))
|
| 40 |
+
sentences = re.split(regex_pattern, text)
|
| 41 |
+
|
| 42 |
+
# Calculate the number of tokens for each sentence
|
| 43 |
+
n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]
|
| 44 |
+
|
| 45 |
+
chunks = []
|
| 46 |
+
current_chunk = []
|
| 47 |
+
current_length = 0
|
| 48 |
+
|
| 49 |
+
for sentence, token_count in zip(sentences, n_tokens):
|
| 50 |
+
# If the sentence is empty or consists only of whitespace, skip it
|
| 51 |
+
if not sentence.strip():
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
# If the sentence is too long, split it into smaller parts
|
| 55 |
+
if token_count > max_tokens:
|
| 56 |
+
sub_sentences = re.split(r"[,;:]", sentence)
|
| 57 |
+
|
| 58 |
+
# there is no need to keep empty os only-spaced strings
|
| 59 |
+
# since spaces will be inserted in the beginning of the full string
|
| 60 |
+
# and in between the string in the sub_chuk list
|
| 61 |
+
filtered_sub_sentences = [sub.strip() for sub in sub_sentences if sub.strip() != ""]
|
| 62 |
+
sub_token_counts = [len(tokenizer.encode(" " + sub_sentence)) for sub_sentence in filtered_sub_sentences]
|
| 63 |
+
|
| 64 |
+
sub_chunk = []
|
| 65 |
+
sub_length = 0
|
| 66 |
+
|
| 67 |
+
for sub_sentence, sub_token_count in zip(filtered_sub_sentences, sub_token_counts):
|
| 68 |
+
if sub_length + sub_token_count > max_tokens:
|
| 69 |
+
|
| 70 |
+
# if the phrase does not have sub_sentences, it would create an empty chunk
|
| 71 |
+
# this big phrase would be added anyways in the next chunk append
|
| 72 |
+
if sub_chunk:
|
| 73 |
+
chunks.append(" ".join(sub_chunk))
|
| 74 |
+
sub_chunk = sub_chunk[-overlap:] if overlap > 0 else []
|
| 75 |
+
sub_length = sum(sub_token_counts[max(0, len(sub_chunk) - overlap):len(sub_chunk)])
|
| 76 |
+
|
| 77 |
+
sub_chunk.append(sub_sentence)
|
| 78 |
+
sub_length += sub_token_count
|
| 79 |
+
|
| 80 |
+
if sub_chunk:
|
| 81 |
+
chunks.append(" ".join(sub_chunk))
|
| 82 |
+
|
| 83 |
+
# If adding the sentence to the current chunk exceeds the max tokens, start a new chunk
|
| 84 |
+
elif current_length + token_count > max_tokens:
|
| 85 |
+
chunks.append(" ".join(current_chunk))
|
| 86 |
+
current_chunk = current_chunk[-overlap:] if overlap > 0 else []
|
| 87 |
+
current_length = sum(n_tokens[max(0, len(current_chunk) - overlap):len(current_chunk)])
|
| 88 |
+
current_chunk.append(sentence)
|
| 89 |
+
current_length += token_count
|
| 90 |
+
|
| 91 |
+
# Otherwise, add the sentence to the current chunk
|
| 92 |
+
else:
|
| 93 |
+
current_chunk.append(sentence)
|
| 94 |
+
current_length += token_count
|
| 95 |
+
|
| 96 |
+
# Add the last chunk if it's not empty
|
| 97 |
+
if current_chunk:
|
| 98 |
+
chunks.append(" ".join(current_chunk))
|
| 99 |
+
|
| 100 |
+
return chunks
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def distances_from_embeddings(
|
| 104 |
+
query_embedding: List[float],
|
| 105 |
+
embeddings: List[List[float]],
|
| 106 |
+
distance_metric: str = "cosine",
|
| 107 |
+
) -> List[float]:
|
| 108 |
+
"""
|
| 109 |
+
Calculates the distances between a query embedding and a list of embeddings.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
query_embedding (List[float]): The query embedding.
|
| 113 |
+
embeddings (List[List[float]]): A list of embeddings to compare against the query embedding.
|
| 114 |
+
distance_metric (str, optional): The distance metric to use for calculation. Defaults to 'cosine'.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
List[float]: The calculated distances between the query embedding and the list of embeddings.
|
| 118 |
+
"""
|
| 119 |
+
distance_metrics = {
|
| 120 |
+
"cosine": spatial.distance.cosine,
|
| 121 |
+
"L1": spatial.distance.cityblock,
|
| 122 |
+
"L2": spatial.distance.euclidean,
|
| 123 |
+
"Linf": spatial.distance.chebyshev,
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
if distance_metric not in distance_metrics:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"Unsupported distance metric '{distance_metric}'. Supported metrics are: {list(distance_metrics.keys())}"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
distances = [
|
| 132 |
+
distance_metrics[distance_metric](query_embedding, embedding)
|
| 133 |
+
for embedding in embeddings
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
return distances
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_node_list(node_dict: Dict[int, Node]) -> List[Node]:
|
| 140 |
+
"""
|
| 141 |
+
Converts a dictionary of node indices to a sorted list of nodes.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
node_dict (Dict[int, Node]): Dictionary of node indices to nodes.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
List[Node]: Sorted list of nodes.
|
| 148 |
+
"""
|
| 149 |
+
indices = sorted(node_dict.keys())
|
| 150 |
+
node_list = [node_dict[index] for index in indices]
|
| 151 |
+
return node_list
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_embeddings(node_list: List[Node], embedding_model: str) -> List:
|
| 155 |
+
"""
|
| 156 |
+
Extracts the embeddings of nodes from a list of nodes.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
node_list (List[Node]): List of nodes.
|
| 160 |
+
embedding_model (str): The name of the embedding model to be used.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
List: List of node embeddings.
|
| 164 |
+
"""
|
| 165 |
+
return [node.embeddings[embedding_model] for node in node_list]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def get_children(node_list: List[Node]) -> List[Set[int]]:
|
| 169 |
+
"""
|
| 170 |
+
Extracts the children of nodes from a list of nodes.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
node_list (List[Node]): List of nodes.
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
List[Set[int]]: List of sets of node children indices.
|
| 177 |
+
"""
|
| 178 |
+
return [node.children for node in node_list]
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def get_text(node_list: List[Node]) -> str:
|
| 182 |
+
"""
|
| 183 |
+
Generates a single text string by concatenating the text from a list of nodes.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
node_list (List[Node]): List of nodes.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
str: Concatenated text.
|
| 190 |
+
"""
|
| 191 |
+
text = ""
|
| 192 |
+
for node in node_list:
|
| 193 |
+
text += f"{' '.join(node.text.splitlines())}"
|
| 194 |
+
text += "\n\n"
|
| 195 |
+
return text
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def indices_of_nearest_neighbors_from_distances(distances: List[float]) -> np.ndarray:
|
| 199 |
+
"""
|
| 200 |
+
Returns the indices of nearest neighbors sorted in ascending order of distance.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
distances (List[float]): A list of distances between embeddings.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
np.ndarray: An array of indices sorted by ascending distance.
|
| 207 |
+
"""
|
| 208 |
+
return np.argsort(distances)
|
baselines/raptor/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
faiss-cpu
|
| 2 |
+
numpy==1.26.3
|
| 3 |
+
openai==1.3.3
|
| 4 |
+
scikit-learn
|
| 5 |
+
sentence-transformers==2.2.2
|
| 6 |
+
tenacity==8.2.3
|
| 7 |
+
tiktoken==0.5.1
|
| 8 |
+
torch
|
| 9 |
+
transformers==4.38.1
|
| 10 |
+
umap-learn==0.5.5
|
| 11 |
+
urllib3==1.26.6
|
baselines/raptor/run_raptor_baseline.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAPTOR baseline for the EvolV-Mem benchmark.
|
| 3 |
+
|
| 4 |
+
Builds a RAPTOR tree per question from session summaries, retrieves context
|
| 5 |
+
via collapsed-tree retrieval, and generates answers using Qwen-30B via vLLM.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python baselines/raptor/run_raptor_baseline.py \
|
| 9 |
+
--in_file dataset/evolv_mem_v4.json \
|
| 10 |
+
--out_file output/raptor_qwen30b.jsonl \
|
| 11 |
+
--summary_file dataset/all_session_summary.json \
|
| 12 |
+
--profile_file metadata/generated_user_profile.json
|
| 13 |
+
|
| 14 |
+
Env vars:
|
| 15 |
+
VLLM_BASE_URL (default http://localhost:8000/v1)
|
| 16 |
+
VLLM_API_KEY (default EMPTY)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import copy
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import os
|
| 24 |
+
import pickle
|
| 25 |
+
import re
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from abc import ABC, abstractmethod
|
| 29 |
+
from collections import defaultdict
|
| 30 |
+
from typing import Dict, List, Optional
|
| 31 |
+
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Add the raptor package to the path so we can import it directly
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 38 |
+
sys.path.insert(0, SCRIPT_DIR)
|
| 39 |
+
|
| 40 |
+
from raptor import (
|
| 41 |
+
BaseEmbeddingModel,
|
| 42 |
+
BaseQAModel,
|
| 43 |
+
BaseSummarizationModel,
|
| 44 |
+
RetrievalAugmentation,
|
| 45 |
+
RetrievalAugmentationConfig,
|
| 46 |
+
SBertEmbeddingModel,
|
| 47 |
+
TreeRetriever,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# vLLM-backed Summarization Model
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
class VLLMSummarizationModel(BaseSummarizationModel):
|
| 57 |
+
"""Summarization model backed by a vLLM OpenAI-compatible server."""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
model_name: str = None,
|
| 62 |
+
base_url: str = None,
|
| 63 |
+
api_key: str = None,
|
| 64 |
+
):
|
| 65 |
+
from openai import OpenAI
|
| 66 |
+
|
| 67 |
+
self.model_name = (
|
| 68 |
+
model_name
|
| 69 |
+
or os.getenv("VLLM_MODEL_NAME")
|
| 70 |
+
or "Qwen/Qwen3-30B-A3B-Instruct-2507"
|
| 71 |
+
)
|
| 72 |
+
self.client = OpenAI(
|
| 73 |
+
base_url=base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
|
| 74 |
+
api_key=api_key or os.getenv("VLLM_API_KEY", "EMPTY"),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def summarize(self, context, max_tokens=2048):
|
| 78 |
+
for attempt in range(6):
|
| 79 |
+
try:
|
| 80 |
+
response = self.client.chat.completions.create(
|
| 81 |
+
model=self.model_name,
|
| 82 |
+
messages=[
|
| 83 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 84 |
+
{
|
| 85 |
+
"role": "user",
|
| 86 |
+
"content": (
|
| 87 |
+
"Write a summary of the following, including as many "
|
| 88 |
+
"key details as possible:\n\n"
|
| 89 |
+
f"{context}"
|
| 90 |
+
),
|
| 91 |
+
},
|
| 92 |
+
],
|
| 93 |
+
max_tokens=max_tokens,
|
| 94 |
+
temperature=0.3,
|
| 95 |
+
)
|
| 96 |
+
content = response.choices[0].message.content if response.choices else None
|
| 97 |
+
if content is None:
|
| 98 |
+
wait = min(2 ** attempt * 2, 30)
|
| 99 |
+
print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
|
| 100 |
+
time.sleep(wait)
|
| 101 |
+
continue
|
| 102 |
+
return content.strip()
|
| 103 |
+
except Exception as e:
|
| 104 |
+
msg = str(e).lower()
|
| 105 |
+
if any(code in msg for code in ("429", "500", "503", "rate limit")):
|
| 106 |
+
wait = min(2 ** attempt * 5, 60)
|
| 107 |
+
print(f"[WARN] Summarization retry {attempt+1}/6, sleeping {wait}s: {e}")
|
| 108 |
+
time.sleep(wait)
|
| 109 |
+
continue
|
| 110 |
+
print(f"[ERROR] Summarization failed: {e}")
|
| 111 |
+
raise
|
| 112 |
+
raise RuntimeError("Summarization failed after 6 retries")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
# vLLM-backed QA Model
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
class VLLMQAModel(BaseQAModel):
|
| 120 |
+
"""QA model backed by a vLLM OpenAI-compatible server.
|
| 121 |
+
|
| 122 |
+
Mutable attributes `question_date` and `user_profile` should be set
|
| 123 |
+
before each call to include per-question context in the prompt.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
model_name: str = None,
|
| 129 |
+
base_url: str = None,
|
| 130 |
+
api_key: str = None,
|
| 131 |
+
):
|
| 132 |
+
from openai import OpenAI
|
| 133 |
+
|
| 134 |
+
self.model_name = (
|
| 135 |
+
model_name
|
| 136 |
+
or os.getenv("VLLM_MODEL_NAME")
|
| 137 |
+
or "Qwen/Qwen3-30B-A3B-Instruct-2507"
|
| 138 |
+
)
|
| 139 |
+
self.client = OpenAI(
|
| 140 |
+
base_url=base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
|
| 141 |
+
api_key=api_key or os.getenv("VLLM_API_KEY", "EMPTY"),
|
| 142 |
+
)
|
| 143 |
+
# Set these per-question before calling answer_question
|
| 144 |
+
self.question_date: Optional[str] = None
|
| 145 |
+
self.user_profile: Optional[str] = None
|
| 146 |
+
|
| 147 |
+
def answer_question(self, context, question):
|
| 148 |
+
# Build prompt matching the project's answer template (main.py:1490)
|
| 149 |
+
parts = []
|
| 150 |
+
parts.append(
|
| 151 |
+
"I will give you several chat history sessions between you and a user. "
|
| 152 |
+
"Please answer the question given the information."
|
| 153 |
+
)
|
| 154 |
+
if self.user_profile:
|
| 155 |
+
parts.append(f"\n\nUser Profile:\n{self.user_profile}")
|
| 156 |
+
parts.append(f"\n\nChat history sessions:\n\n{context}")
|
| 157 |
+
if self.question_date:
|
| 158 |
+
parts.append(f"\n\nCurrent Date: {self.question_date}")
|
| 159 |
+
parts.append(f"\nQuestion: {question}\nAnswer:")
|
| 160 |
+
|
| 161 |
+
prompt = "".join(parts)
|
| 162 |
+
|
| 163 |
+
for attempt in range(6):
|
| 164 |
+
try:
|
| 165 |
+
response = self.client.chat.completions.create(
|
| 166 |
+
model=self.model_name,
|
| 167 |
+
messages=[{"role": "user", "content": prompt}],
|
| 168 |
+
max_tokens=8192,
|
| 169 |
+
temperature=0.3,
|
| 170 |
+
)
|
| 171 |
+
content = response.choices[0].message.content if response.choices else None
|
| 172 |
+
if content is None:
|
| 173 |
+
wait = min(2 ** attempt * 2, 30)
|
| 174 |
+
print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
|
| 175 |
+
time.sleep(wait)
|
| 176 |
+
continue
|
| 177 |
+
return content.strip()
|
| 178 |
+
except Exception as e:
|
| 179 |
+
msg = str(e).lower()
|
| 180 |
+
if any(code in msg for code in ("429", "500", "503", "rate limit")):
|
| 181 |
+
wait = min(2 ** attempt * 5, 60)
|
| 182 |
+
print(f"[WARN] QA retry {attempt+1}/6, sleeping {wait}s: {e}")
|
| 183 |
+
time.sleep(wait)
|
| 184 |
+
continue
|
| 185 |
+
print(f"[ERROR] QA failed: {e}")
|
| 186 |
+
raise
|
| 187 |
+
raise RuntimeError("QA failed after 6 retries")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# ---------------------------------------------------------------------------
|
| 191 |
+
# Data helpers
|
| 192 |
+
# ---------------------------------------------------------------------------
|
| 193 |
+
|
| 194 |
+
def prepare_session_documents(
|
| 195 |
+
haystack_session_ids: List[str],
|
| 196 |
+
haystack_dates: List[str],
|
| 197 |
+
summaries: Dict,
|
| 198 |
+
) -> List[str]:
|
| 199 |
+
"""Format session summaries as RAPTOR leaf documents.
|
| 200 |
+
|
| 201 |
+
Each document is a short block:
|
| 202 |
+
[Session {sid} | Date: {date}]
|
| 203 |
+
{session_summary_text}
|
| 204 |
+
"""
|
| 205 |
+
docs = []
|
| 206 |
+
for sid, date_str in zip(haystack_session_ids, haystack_dates):
|
| 207 |
+
summary_data = summaries.get(sid)
|
| 208 |
+
if summary_data is None:
|
| 209 |
+
continue
|
| 210 |
+
text = summary_data.get("session_summary", "")
|
| 211 |
+
if not text:
|
| 212 |
+
# Fallback: join turn summaries
|
| 213 |
+
turn_sums = summary_data.get("turn_summaries", [])
|
| 214 |
+
if turn_sums:
|
| 215 |
+
text = " ".join(turn_sums)
|
| 216 |
+
else:
|
| 217 |
+
continue
|
| 218 |
+
doc = f"[Session {sid} | Date: {date_str}]\n{text}"
|
| 219 |
+
docs.append(doc)
|
| 220 |
+
return docs
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# ---------------------------------------------------------------------------
|
| 224 |
+
# Tree building / caching
|
| 225 |
+
# ---------------------------------------------------------------------------
|
| 226 |
+
|
| 227 |
+
def build_or_load_tree(
|
| 228 |
+
question_id: str,
|
| 229 |
+
docs: List[str],
|
| 230 |
+
tree_builder,
|
| 231 |
+
tree_cache_dir: str,
|
| 232 |
+
):
|
| 233 |
+
"""Build a RAPTOR tree from docs, or load from cache."""
|
| 234 |
+
tree_path = os.path.join(tree_cache_dir, f"{question_id}.pkl")
|
| 235 |
+
|
| 236 |
+
if os.path.exists(tree_path):
|
| 237 |
+
logging.info(f"Loading cached tree for {question_id}")
|
| 238 |
+
with open(tree_path, "rb") as f:
|
| 239 |
+
tree = pickle.load(f)
|
| 240 |
+
return tree
|
| 241 |
+
|
| 242 |
+
# Join docs with double-newline separator so RAPTOR's split_text keeps
|
| 243 |
+
# each ~230-token summary as a single leaf node (with tb_max_tokens=300).
|
| 244 |
+
text = "\n\n".join(docs)
|
| 245 |
+
|
| 246 |
+
logging.info(f"Building tree for {question_id} ({len(docs)} docs)")
|
| 247 |
+
tree = tree_builder.build_from_text(text, use_multithreading=True)
|
| 248 |
+
|
| 249 |
+
os.makedirs(tree_cache_dir, exist_ok=True)
|
| 250 |
+
with open(tree_path, "wb") as f:
|
| 251 |
+
pickle.dump(tree, f)
|
| 252 |
+
logging.info(f"Saved tree for {question_id} -> {tree_path}")
|
| 253 |
+
|
| 254 |
+
return tree
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# ---------------------------------------------------------------------------
|
| 258 |
+
# Retrieval metrics
|
| 259 |
+
# ---------------------------------------------------------------------------
|
| 260 |
+
|
| 261 |
+
_SESSION_ID_RE = re.compile(r"\[Session\s+(\S+)\s*\|")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def extract_session_ids_from_context(context: str) -> List[str]:
|
| 265 |
+
"""Parse session IDs from RAPTOR-retrieved context text.
|
| 266 |
+
|
| 267 |
+
Leaf nodes are formatted as '[Session {sid} | Date: ...]\\n{summary}'.
|
| 268 |
+
Higher-level nodes are summaries of clusters and won't contain session IDs.
|
| 269 |
+
"""
|
| 270 |
+
return list(dict.fromkeys(_SESSION_ID_RE.findall(context))) # unique, order-preserving
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def evaluate_retrieval(recalled_docs, correct_docs):
|
| 274 |
+
recall_any = float(any(doc in recalled_docs for doc in correct_docs))
|
| 275 |
+
recall_all = float(all(doc in recalled_docs for doc in correct_docs))
|
| 276 |
+
return recall_any, recall_all
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def print_average_metrics(retrieval_metric_list):
|
| 280 |
+
metric_sums = defaultdict(float)
|
| 281 |
+
metric_counts = defaultdict(int)
|
| 282 |
+
for metric in retrieval_metric_list:
|
| 283 |
+
for k, v in metric.items():
|
| 284 |
+
metric_sums[k] += v
|
| 285 |
+
metric_counts[k] += 1
|
| 286 |
+
print(" Average retrieval metrics:")
|
| 287 |
+
for k in sorted(metric_sums):
|
| 288 |
+
avg = metric_sums[k] / metric_counts[k]
|
| 289 |
+
print(f" {k}: {avg:.4f}")
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ---------------------------------------------------------------------------
|
| 293 |
+
# Main
|
| 294 |
+
# ---------------------------------------------------------------------------
|
| 295 |
+
|
| 296 |
+
def main():
|
| 297 |
+
parser = argparse.ArgumentParser(description="RAPTOR baseline for EvolV-Mem")
|
| 298 |
+
parser.add_argument("--in_file", type=str, required=True,
|
| 299 |
+
help="Path to evolv_mem_v4.json")
|
| 300 |
+
parser.add_argument("--out_file", type=str, required=True,
|
| 301 |
+
help="Output JSONL file")
|
| 302 |
+
parser.add_argument("--summary_file", type=str, required=True,
|
| 303 |
+
help="Path to all_session_summary.json")
|
| 304 |
+
parser.add_argument("--profile_file", type=str, default=None,
|
| 305 |
+
help="Path to generated_user_profile.json")
|
| 306 |
+
parser.add_argument("--tree_cache_dir", type=str,
|
| 307 |
+
default="baselines/raptor/trees",
|
| 308 |
+
help="Directory to cache built trees")
|
| 309 |
+
# RAPTOR tree builder params
|
| 310 |
+
parser.add_argument("--tb_max_tokens", type=int, default=300,
|
| 311 |
+
help="Max tokens per leaf chunk (default 300)")
|
| 312 |
+
parser.add_argument("--tb_num_layers", type=int, default=3,
|
| 313 |
+
help="Number of tree layers (default 3)")
|
| 314 |
+
parser.add_argument("--tb_summarization_length", type=int, default=200,
|
| 315 |
+
help="Max tokens per cluster summary (default 200)")
|
| 316 |
+
# RAPTOR retrieval params
|
| 317 |
+
parser.add_argument("--tr_top_k", type=int, default=10,
|
| 318 |
+
help="Top-k nodes to retrieve (default 10)")
|
| 319 |
+
parser.add_argument("--max_retrieval_tokens", type=int, default=8000,
|
| 320 |
+
help="Token budget for retrieved context (default 8000)")
|
| 321 |
+
# Embedding model
|
| 322 |
+
parser.add_argument("--embedding_model", type=str,
|
| 323 |
+
default="sentence-transformers/multi-qa-mpnet-base-cos-v1",
|
| 324 |
+
help="SentenceTransformer model for embeddings")
|
| 325 |
+
# Index range (for parallel jobs)
|
| 326 |
+
parser.add_argument("--start_idx", type=int, default=None,
|
| 327 |
+
help="Start index (inclusive) for question subset")
|
| 328 |
+
parser.add_argument("--end_idx", type=int, default=None,
|
| 329 |
+
help="End index (exclusive) for question subset")
|
| 330 |
+
# Limit (for debugging)
|
| 331 |
+
parser.add_argument("--limit", type=int, default=None,
|
| 332 |
+
help="Process only the first N questions")
|
| 333 |
+
args = parser.parse_args()
|
| 334 |
+
|
| 335 |
+
# -----------------------------------------------------------------------
|
| 336 |
+
# Load data
|
| 337 |
+
# -----------------------------------------------------------------------
|
| 338 |
+
print(f"Loading benchmark from {args.in_file} ...")
|
| 339 |
+
with open(args.in_file) as f:
|
| 340 |
+
benchmark = json.load(f)
|
| 341 |
+
if args.start_idx is not None or args.end_idx is not None:
|
| 342 |
+
s = args.start_idx or 0
|
| 343 |
+
e = args.end_idx or len(benchmark)
|
| 344 |
+
benchmark = benchmark[s:e]
|
| 345 |
+
print(f" Using index range [{s}, {e})")
|
| 346 |
+
if args.limit:
|
| 347 |
+
benchmark = benchmark[: args.limit]
|
| 348 |
+
print(f" {len(benchmark)} questions loaded.")
|
| 349 |
+
|
| 350 |
+
print(f"Loading session summaries from {args.summary_file} ...")
|
| 351 |
+
with open(args.summary_file) as f:
|
| 352 |
+
summaries = json.load(f)
|
| 353 |
+
print(f" {len(summaries)} sessions loaded.")
|
| 354 |
+
|
| 355 |
+
profiles = {}
|
| 356 |
+
if args.profile_file and os.path.exists(args.profile_file):
|
| 357 |
+
print(f"Loading user profiles from {args.profile_file} ...")
|
| 358 |
+
with open(args.profile_file) as f:
|
| 359 |
+
profiles = json.load(f)
|
| 360 |
+
print(f" {len(profiles)} profiles loaded.")
|
| 361 |
+
|
| 362 |
+
# -----------------------------------------------------------------------
|
| 363 |
+
# Resume support: load existing output
|
| 364 |
+
# -----------------------------------------------------------------------
|
| 365 |
+
existing_qids = set()
|
| 366 |
+
if os.path.exists(args.out_file):
|
| 367 |
+
with open(args.out_file) as f:
|
| 368 |
+
for line in f:
|
| 369 |
+
line = line.strip()
|
| 370 |
+
if line:
|
| 371 |
+
obj = json.loads(line)
|
| 372 |
+
existing_qids.add(obj["question_id"])
|
| 373 |
+
print(f" Resuming: {len(existing_qids)} questions already processed.")
|
| 374 |
+
|
| 375 |
+
# -----------------------------------------------------------------------
|
| 376 |
+
# Initialize models
|
| 377 |
+
# -----------------------------------------------------------------------
|
| 378 |
+
print("Initializing models ...")
|
| 379 |
+
embedding_model = SBertEmbeddingModel(model_name=args.embedding_model)
|
| 380 |
+
summarization_model = VLLMSummarizationModel()
|
| 381 |
+
qa_model = VLLMQAModel()
|
| 382 |
+
|
| 383 |
+
# -----------------------------------------------------------------------
|
| 384 |
+
# Build RAPTOR config
|
| 385 |
+
# -----------------------------------------------------------------------
|
| 386 |
+
config = RetrievalAugmentationConfig(
|
| 387 |
+
summarization_model=summarization_model,
|
| 388 |
+
qa_model=qa_model,
|
| 389 |
+
embedding_model=embedding_model,
|
| 390 |
+
tb_max_tokens=args.tb_max_tokens,
|
| 391 |
+
tb_num_layers=args.tb_num_layers,
|
| 392 |
+
tb_summarization_length=args.tb_summarization_length,
|
| 393 |
+
tr_top_k=args.tr_top_k,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# Pre-create tree builder (reused across questions)
|
| 397 |
+
tree_builder = config.tree_builder_config
|
| 398 |
+
# We need the actual builder instance from a fresh RA to reuse
|
| 399 |
+
ra_template = RetrievalAugmentation(config=config)
|
| 400 |
+
tree_builder_instance = ra_template.tree_builder
|
| 401 |
+
|
| 402 |
+
os.makedirs(args.tree_cache_dir, exist_ok=True)
|
| 403 |
+
|
| 404 |
+
# -----------------------------------------------------------------------
|
| 405 |
+
# Process questions
|
| 406 |
+
# -----------------------------------------------------------------------
|
| 407 |
+
retrieval_metric_list = []
|
| 408 |
+
out_f = open(args.out_file, "a")
|
| 409 |
+
|
| 410 |
+
for di, entry in enumerate(tqdm(benchmark, desc="RAPTOR baseline")):
|
| 411 |
+
qid = entry["question_id"]
|
| 412 |
+
question = entry["question"]
|
| 413 |
+
question_date = entry["question_date"]
|
| 414 |
+
|
| 415 |
+
if qid in existing_qids:
|
| 416 |
+
continue
|
| 417 |
+
|
| 418 |
+
try:
|
| 419 |
+
# 1. Prepare documents from session summaries
|
| 420 |
+
docs = prepare_session_documents(
|
| 421 |
+
entry["haystack_session_ids"],
|
| 422 |
+
entry["haystack_dates"],
|
| 423 |
+
summaries,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
if not docs:
|
| 427 |
+
print(f"[WARN] q_idx={di} qid={qid}: no session summaries found, skipping.")
|
| 428 |
+
result = {
|
| 429 |
+
"q_idx": di,
|
| 430 |
+
"question_id": qid,
|
| 431 |
+
"hypothesis": "Insufficient information to answer.",
|
| 432 |
+
"n_docs": 0,
|
| 433 |
+
}
|
| 434 |
+
print(json.dumps(result), file=out_f, flush=True)
|
| 435 |
+
continue
|
| 436 |
+
|
| 437 |
+
# 2. Build or load RAPTOR tree
|
| 438 |
+
tree = build_or_load_tree(
|
| 439 |
+
qid, docs, tree_builder_instance, args.tree_cache_dir
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# 3. Create RA instance with this tree
|
| 443 |
+
ra = RetrievalAugmentation(config=config, tree=tree)
|
| 444 |
+
|
| 445 |
+
# 4. Set per-question context on the QA model
|
| 446 |
+
qa_model.question_date = question_date
|
| 447 |
+
user_id = qid.split("_q_")[0] if "_q_" in qid else qid
|
| 448 |
+
qa_model.user_profile = profiles.get(user_id, None)
|
| 449 |
+
|
| 450 |
+
# 5. Retrieve context (separate from QA so we can extract session IDs)
|
| 451 |
+
context, layer_info = ra.retrieve(
|
| 452 |
+
question=question,
|
| 453 |
+
top_k=args.tr_top_k,
|
| 454 |
+
max_tokens=args.max_retrieval_tokens,
|
| 455 |
+
collapse_tree=True,
|
| 456 |
+
return_layer_information=True,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# 5a. Extract retrieved session IDs from context text
|
| 460 |
+
retrieved_session_ids = extract_session_ids_from_context(context)
|
| 461 |
+
|
| 462 |
+
# 5b. Generate answer manually
|
| 463 |
+
answer = qa_model.answer_question(context, question)
|
| 464 |
+
|
| 465 |
+
# 5c. Compute retrieval metrics
|
| 466 |
+
answer_session_ids = entry.get("answer_session_ids", [])
|
| 467 |
+
retrieval_metric = {}
|
| 468 |
+
if answer_session_ids and retrieved_session_ids:
|
| 469 |
+
for topk in [5, 10, 20, 30]:
|
| 470 |
+
r_any, r_all = evaluate_retrieval(
|
| 471 |
+
retrieved_session_ids[:topk], answer_session_ids
|
| 472 |
+
)
|
| 473 |
+
retrieval_metric[f"recall_any@{topk}"] = r_any
|
| 474 |
+
retrieval_metric[f"recall_all@{topk}"] = r_all
|
| 475 |
+
retrieval_metric_list.append(retrieval_metric)
|
| 476 |
+
print_average_metrics(retrieval_metric_list)
|
| 477 |
+
|
| 478 |
+
# 6. Write output
|
| 479 |
+
result = {
|
| 480 |
+
"q_idx": di,
|
| 481 |
+
"question_id": qid,
|
| 482 |
+
"hypothesis": answer,
|
| 483 |
+
"n_docs": len(docs),
|
| 484 |
+
"n_tree_nodes": len(tree.all_nodes) if hasattr(tree, "all_nodes") else -1,
|
| 485 |
+
"n_tree_layers": tree.num_layers if hasattr(tree, "num_layers") else -1,
|
| 486 |
+
"retrieved_session_ids": retrieved_session_ids,
|
| 487 |
+
"retrieval_metric": retrieval_metric,
|
| 488 |
+
}
|
| 489 |
+
print(json.dumps(result), file=out_f, flush=True)
|
| 490 |
+
|
| 491 |
+
print(
|
| 492 |
+
f"[{di}/{len(benchmark)}] qid={qid} | "
|
| 493 |
+
f"docs={len(docs)} | nodes={result['n_tree_nodes']} | "
|
| 494 |
+
f"layers={result['n_tree_layers']} | "
|
| 495 |
+
f"retrieved_sessions={len(retrieved_session_ids)}"
|
| 496 |
+
)
|
| 497 |
+
print(f" Q: {question[:100]}...")
|
| 498 |
+
print(f" A: {answer[:200]}...")
|
| 499 |
+
|
| 500 |
+
except Exception as e:
|
| 501 |
+
print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
|
| 502 |
+
import traceback
|
| 503 |
+
traceback.print_exc()
|
| 504 |
+
continue
|
| 505 |
+
|
| 506 |
+
out_f.close()
|
| 507 |
+
print(f"\nDone. Results saved to {args.out_file}")
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
if __name__ == "__main__":
|
| 511 |
+
main()
|
baselines/read-agent/read_agent_demo.ipynb
ADDED
|
@@ -0,0 +1,976 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"toc_visible": true
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"name": "python3",
|
| 11 |
+
"display_name": "Python 3"
|
| 12 |
+
},
|
| 13 |
+
"language_info": {
|
| 14 |
+
"name": "python"
|
| 15 |
+
}
|
| 16 |
+
},
|
| 17 |
+
"cells": [
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "markdown",
|
| 20 |
+
"source": [
|
| 21 |
+
""
|
| 22 |
+
],
|
| 23 |
+
"metadata": {
|
| 24 |
+
"id": "1iqyV7VcsiXT"
|
| 25 |
+
}
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"metadata": {
|
| 31 |
+
"id": "MYOnCMh83ZRE"
|
| 32 |
+
},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
"!wget https://github.com/nyu-mll/quality/raw/main/data/v1.0.1/QuALITY.v1.0.1.htmlstripped.dev\n",
|
| 36 |
+
"import re, time, datetime, json, string, copy"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "code",
|
| 41 |
+
"source": [
|
| 42 |
+
"# @title Using OpenAI GPT model (DO NOT run the next cell if using GPT)\n",
|
| 43 |
+
"!pip3 install openai\n",
|
| 44 |
+
"import openai\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"key = 'YOUR API KEY' #@param {type: \"string\"}\n",
|
| 47 |
+
"gpt_client = openai.OpenAI(api_key=key)\n",
|
| 48 |
+
"model_type = 'gpt'\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"def query_gpt_model(\n",
|
| 51 |
+
" prompt: str,\n",
|
| 52 |
+
" lm: str = 'gpt-3.5-turbo-1106',\n",
|
| 53 |
+
" temperature: float = 0.0,\n",
|
| 54 |
+
" max_decode_steps: int = 512,\n",
|
| 55 |
+
" seconds_to_reset_tokens: float = 30.0,\n",
|
| 56 |
+
") -> str:\n",
|
| 57 |
+
" while True:\n",
|
| 58 |
+
" try:\n",
|
| 59 |
+
" raw_response = gpt_client.chat.completions.with_raw_response.create(\n",
|
| 60 |
+
" model=lm,\n",
|
| 61 |
+
" max_tokens=max_decode_steps,\n",
|
| 62 |
+
" temperature=temperature,\n",
|
| 63 |
+
" messages=[\n",
|
| 64 |
+
" {'role': 'user', 'content': prompt},\n",
|
| 65 |
+
" ]\n",
|
| 66 |
+
" )\n",
|
| 67 |
+
" completion = raw_response.parse()\n",
|
| 68 |
+
" return completion.choices[0].message.content\n",
|
| 69 |
+
" except openai.RateLimitError as e:\n",
|
| 70 |
+
" print(f'{datetime.datetime.now()}: query_gpt_model: RateLimitError {e.message}: {e}')\n",
|
| 71 |
+
" time.sleep(seconds_to_reset_tokens)\n",
|
| 72 |
+
" except openai.APIError as e:\n",
|
| 73 |
+
" print(f'{datetime.datetime.now()}: query_gpt_model: APIError {e.message}: {e}')\n",
|
| 74 |
+
" print(f'{datetime.datetime.now()}: query_gpt_model: Retrying after 5 seconds...')\n",
|
| 75 |
+
" time.sleep(5)"
|
| 76 |
+
],
|
| 77 |
+
"metadata": {
|
| 78 |
+
"id": "oz0kOxYJ4n3e",
|
| 79 |
+
"cellView": "form"
|
| 80 |
+
},
|
| 81 |
+
"execution_count": null,
|
| 82 |
+
"outputs": []
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"source": [
|
| 87 |
+
"# @title Using Google Gemini model (DO NOT run this if using GPT)\n",
|
| 88 |
+
"!pip3 install -q -U google-generativeai\n",
|
| 89 |
+
"import google.generativeai as genai\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"key = 'YOUR API KEY' #@param {type: \"string\"}\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"genai.configure(api_key=key)\n",
|
| 94 |
+
"model = genai.GenerativeModel('gemini-pro')\n",
|
| 95 |
+
"model_type = 'gemini'\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"def query_gemini_model(\n",
|
| 98 |
+
" prompt: str,\n",
|
| 99 |
+
" retries: int = 10,\n",
|
| 100 |
+
") -> str:\n",
|
| 101 |
+
" while True and retries > 0:\n",
|
| 102 |
+
" try:\n",
|
| 103 |
+
" response = model.generate_content(prompt)\n",
|
| 104 |
+
" text_response = response.text.replace(\"**\", \"\")\n",
|
| 105 |
+
" return text_response\n",
|
| 106 |
+
" except Exception as e:\n",
|
| 107 |
+
" print(f'{datetime.datetime.now()}: query_gemini_model: Error: {e}')\n",
|
| 108 |
+
" print(f'{datetime.datetime.now()}: query_gemini_model: Retrying after 5 seconds...')\n",
|
| 109 |
+
" retries -= 1\n",
|
| 110 |
+
" time.sleep(5)"
|
| 111 |
+
],
|
| 112 |
+
"metadata": {
|
| 113 |
+
"cellView": "form",
|
| 114 |
+
"id": "YcP_tIpZKNFY"
|
| 115 |
+
},
|
| 116 |
+
"execution_count": null,
|
| 117 |
+
"outputs": []
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"cell_type": "code",
|
| 121 |
+
"source": [
|
| 122 |
+
"def query_model(prompt):\n",
|
| 123 |
+
" if model_type == \"gpt\":\n",
|
| 124 |
+
" return query_gpt_model(prompt)\n",
|
| 125 |
+
" elif model_type == \"gemini\":\n",
|
| 126 |
+
" return query_gemini_model(prompt)"
|
| 127 |
+
],
|
| 128 |
+
"metadata": {
|
| 129 |
+
"id": "pYm2GsBGEvAI"
|
| 130 |
+
},
|
| 131 |
+
"execution_count": null,
|
| 132 |
+
"outputs": []
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"cell_type": "code",
|
| 136 |
+
"source": [
|
| 137 |
+
"#@title Load a QuALITY example\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"# Fields that are straight text copies from raw example to processed example.\n",
|
| 140 |
+
"_ONE2ONE_FIELDS = (\n",
|
| 141 |
+
" 'article',\n",
|
| 142 |
+
" 'article_id',\n",
|
| 143 |
+
" 'set_unique_id',\n",
|
| 144 |
+
" 'writer_id',\n",
|
| 145 |
+
" 'source',\n",
|
| 146 |
+
" 'title',\n",
|
| 147 |
+
" 'topic',\n",
|
| 148 |
+
" 'url',\n",
|
| 149 |
+
" 'writer_id',\n",
|
| 150 |
+
" 'author',\n",
|
| 151 |
+
")\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"quality_dev = []\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"with open('QuALITY.v1.0.1.htmlstripped.dev', 'r') as f:\n",
|
| 156 |
+
" for line in f.readlines():\n",
|
| 157 |
+
" j = json.loads(line)\n",
|
| 158 |
+
" fields = {k: j[k] for k in _ONE2ONE_FIELDS}\n",
|
| 159 |
+
" fields.update({\n",
|
| 160 |
+
" 'questions': [q['question'] for q in j['questions']],\n",
|
| 161 |
+
" 'question_ids': [q['question_unique_id'] for q in j['questions']],\n",
|
| 162 |
+
" 'difficults': [q['difficult'] for q in j['questions']],\n",
|
| 163 |
+
" 'options': [q['options'] for q in j['questions']],\n",
|
| 164 |
+
" })\n",
|
| 165 |
+
"\n",
|
| 166 |
+
" fields.update({\n",
|
| 167 |
+
" 'gold_labels': [q['gold_label'] for q in j['questions']],\n",
|
| 168 |
+
" 'writer_labels': [q['writer_label'] for q in j['questions']],\n",
|
| 169 |
+
" })\n",
|
| 170 |
+
"\n",
|
| 171 |
+
" quality_dev.append(fields)\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"example = quality_dev[13]"
|
| 174 |
+
],
|
| 175 |
+
"metadata": {
|
| 176 |
+
"id": "1B70Rqg97aXu",
|
| 177 |
+
"cellView": "form"
|
| 178 |
+
},
|
| 179 |
+
"execution_count": null,
|
| 180 |
+
"outputs": []
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"cell_type": "code",
|
| 184 |
+
"source": [
|
| 185 |
+
"#@title Helper functions\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"all_lowercase_letters = string.ascii_lowercase # \"abcd...xyz\"\n",
|
| 188 |
+
"bracketed_lowercase_letters_set = set(\n",
|
| 189 |
+
" [f\"({l})\" for l in all_lowercase_letters]\n",
|
| 190 |
+
") # {\"(a)\", ...}\n",
|
| 191 |
+
"bracketed_uppercase_letters_set = set(\n",
|
| 192 |
+
" [f\"({l.upper()})\" for l in all_lowercase_letters]\n",
|
| 193 |
+
") # {\"(a)\", ...}\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"choices = ['(A)', '(B)', '(C)', '(D)']\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"def get_index_from_symbol(answer):\n",
|
| 198 |
+
" \"\"\"Get the index from the letter symbols A, B, C, D, to extract answer texts.\n",
|
| 199 |
+
"\n",
|
| 200 |
+
" Args:\n",
|
| 201 |
+
" answer (str): the string of answer like \"(B)\".\n",
|
| 202 |
+
"\n",
|
| 203 |
+
" Returns:\n",
|
| 204 |
+
" index (int): how far the given choice is from \"a\", like 1 for answer \"(B)\".\n",
|
| 205 |
+
" \"\"\"\n",
|
| 206 |
+
" answer = str(answer).lower()\n",
|
| 207 |
+
" # extract the choice letter from within bracket\n",
|
| 208 |
+
" if answer in bracketed_lowercase_letters_set:\n",
|
| 209 |
+
" answer = re.findall(r\"\\(.*?\\)\", answer)[0][1]\n",
|
| 210 |
+
" index = ord(answer) - ord(\"a\")\n",
|
| 211 |
+
" return index\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"def count_words(text):\n",
|
| 214 |
+
" \"\"\"Simple word counting.\"\"\"\n",
|
| 215 |
+
" return len(text.split())\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"def quality_gutenberg_parser(raw_article):\n",
|
| 218 |
+
" \"\"\"Parse Gutenberg articles in the QuALITY dataset.\"\"\"\n",
|
| 219 |
+
" lines = []\n",
|
| 220 |
+
" previous_line = None\n",
|
| 221 |
+
" for i, line in enumerate(raw_article.split('\\n')):\n",
|
| 222 |
+
" line = line.strip()\n",
|
| 223 |
+
" original_line = line\n",
|
| 224 |
+
" if line == '':\n",
|
| 225 |
+
" if previous_line == '':\n",
|
| 226 |
+
" line = '\\n'\n",
|
| 227 |
+
" else:\n",
|
| 228 |
+
" previous_line = original_line\n",
|
| 229 |
+
" continue\n",
|
| 230 |
+
" previous_line = original_line\n",
|
| 231 |
+
" lines.append(line)\n",
|
| 232 |
+
" return ' '.join(lines)"
|
| 233 |
+
],
|
| 234 |
+
"metadata": {
|
| 235 |
+
"id": "nQsb3n6pOlz2",
|
| 236 |
+
"cellView": "form"
|
| 237 |
+
},
|
| 238 |
+
"execution_count": null,
|
| 239 |
+
"outputs": []
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"cell_type": "code",
|
| 243 |
+
"source": [
|
| 244 |
+
"#@title ReadAgent (1) Episode Pagination\n",
|
| 245 |
+
"\n",
|
| 246 |
+
"prompt_pagination_template = \"\"\"\n",
|
| 247 |
+
"You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.\n",
|
| 248 |
+
"Numbered label are in angeled brackets. For example, if the label number is 19, it shows as <19> in text.\n",
|
| 249 |
+
"Please choose one label that it is natural to break reading.\n",
|
| 250 |
+
"Such point can be scene transition, end of a dialogue, end of an argument, narrative transition, etc.\n",
|
| 251 |
+
"Please answer the break point label and explain.\n",
|
| 252 |
+
"For example, if <57> is a good point to break, answer with \\\"Break point: <57>\\n Because ...\\\"\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"Passage:\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"{0}\n",
|
| 257 |
+
"{1}\n",
|
| 258 |
+
"{2}\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"\"\"\"\n",
|
| 261 |
+
"\n",
|
| 262 |
+
"def parse_pause_point(text):\n",
|
| 263 |
+
" text = text.strip(\"Break point: \")\n",
|
| 264 |
+
" if text[0] != '<':\n",
|
| 265 |
+
" return None\n",
|
| 266 |
+
" for i, c in enumerate(text):\n",
|
| 267 |
+
" if c == '>':\n",
|
| 268 |
+
" if text[1:i].isnumeric():\n",
|
| 269 |
+
" return int(text[1:i])\n",
|
| 270 |
+
" else:\n",
|
| 271 |
+
" return None\n",
|
| 272 |
+
" return None\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"\n",
|
| 275 |
+
"def quality_pagination(example,\n",
|
| 276 |
+
" word_limit=600,\n",
|
| 277 |
+
" start_threshold=280,\n",
|
| 278 |
+
" max_retires=10,\n",
|
| 279 |
+
" verbose=True,\n",
|
| 280 |
+
" allow_fallback_to_last=True):\n",
|
| 281 |
+
" article = example['article']\n",
|
| 282 |
+
" title = example['title']\n",
|
| 283 |
+
" print(f\"[Pagination][Article {title}]\")\n",
|
| 284 |
+
" paragraphs = quality_gutenberg_parser(article).split('\\n')\n",
|
| 285 |
+
"\n",
|
| 286 |
+
" i = 0\n",
|
| 287 |
+
" pages = []\n",
|
| 288 |
+
" while i < len(paragraphs):\n",
|
| 289 |
+
" preceding = \"\" if i == 0 else \"...\\n\" + '\\n'.join(pages[-1])\n",
|
| 290 |
+
" passage = [paragraphs[i]]\n",
|
| 291 |
+
" wcount = count_words(paragraphs[i])\n",
|
| 292 |
+
" j = i + 1\n",
|
| 293 |
+
" while wcount < word_limit and j < len(paragraphs):\n",
|
| 294 |
+
" wcount += count_words(paragraphs[j])\n",
|
| 295 |
+
" if wcount >= start_threshold:\n",
|
| 296 |
+
" passage.append(f\"<{j}>\")\n",
|
| 297 |
+
" passage.append(paragraphs[j])\n",
|
| 298 |
+
" j += 1\n",
|
| 299 |
+
" passage.append(f\"<{j}>\")\n",
|
| 300 |
+
" end_tag = \"\" if j == len(paragraphs) else paragraphs[j] + \"\\n...\"\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" pause_point = None\n",
|
| 303 |
+
" if wcount < 350:\n",
|
| 304 |
+
" pause_point = len(paragraphs)\n",
|
| 305 |
+
" else:\n",
|
| 306 |
+
" prompt = prompt_pagination_template.format(preceding, '\\n'.join(passage), end_tag)\n",
|
| 307 |
+
" response = query_model(prompt=prompt).strip()\n",
|
| 308 |
+
" pause_point = parse_pause_point(response)\n",
|
| 309 |
+
" if pause_point and (pause_point <= i or pause_point > j):\n",
|
| 310 |
+
" print(f\"prompt:\\n{prompt},\\nresponse:\\n{response}\\n\")\n",
|
| 311 |
+
" print(f\"i:{i} j:{j} pause_point:{pause_point}\")\n",
|
| 312 |
+
" pause_point = None\n",
|
| 313 |
+
" if pause_point is None:\n",
|
| 314 |
+
" if allow_fallback_to_last:\n",
|
| 315 |
+
" pause_point = j\n",
|
| 316 |
+
" else:\n",
|
| 317 |
+
" raise ValueError(f\"prompt:\\n{prompt},\\nresponse:\\n{response}\\n\")\n",
|
| 318 |
+
"\n",
|
| 319 |
+
" page = paragraphs[i:pause_point]\n",
|
| 320 |
+
" pages.append(page)\n",
|
| 321 |
+
" if verbose:\n",
|
| 322 |
+
" print(f\"Paragraph {i}-{pause_point-1}\", page)\n",
|
| 323 |
+
" i = pause_point\n",
|
| 324 |
+
" print(f\"[Pagination] Done with {len(pages)} pages\")\n",
|
| 325 |
+
" return pages\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"pages = quality_pagination(example)"
|
| 328 |
+
],
|
| 329 |
+
"metadata": {
|
| 330 |
+
"id": "BfFkEQKx0u9U"
|
| 331 |
+
},
|
| 332 |
+
"execution_count": null,
|
| 333 |
+
"outputs": []
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"cell_type": "code",
|
| 337 |
+
"source": [
|
| 338 |
+
"#@title ReadAgent (2) Memory Gisting\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"prompt_shorten_template = \"\"\"\n",
|
| 341 |
+
"Please shorten the following passage.\n",
|
| 342 |
+
"Just give me a shortened version. DO NOT explain your reason.\n",
|
| 343 |
+
"\n",
|
| 344 |
+
"Passage:\n",
|
| 345 |
+
"{}\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"\"\"\"\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"def quality_gisting(example, pages, word_limit=600, start_threshold=280, verbose=True):\n",
|
| 350 |
+
" article = example['article']\n",
|
| 351 |
+
" title = example['title']\n",
|
| 352 |
+
" word_count = count_words(article)\n",
|
| 353 |
+
" print(f\"[Gisting][Article {title}], {word_count} words\")\n",
|
| 354 |
+
"\n",
|
| 355 |
+
" shortened_pages = []\n",
|
| 356 |
+
" for i, page in enumerate(pages):\n",
|
| 357 |
+
" prompt = prompt_shorten_template.format('\\n'.join(page))\n",
|
| 358 |
+
" response = query_model(prompt)\n",
|
| 359 |
+
" shortened_text = response.strip()\n",
|
| 360 |
+
" shortened_pages.append(shortened_text)\n",
|
| 361 |
+
" if verbose:\n",
|
| 362 |
+
" print(\"[gist] page {}:\".format(i), shortened_text, flush=True)\n",
|
| 363 |
+
" shortened_article = '\\n'.join(shortened_pages)\n",
|
| 364 |
+
" gist_word_count = count_words(shortened_article)\n",
|
| 365 |
+
" if verbose:\n",
|
| 366 |
+
" print(\"Shortened article:\\n\", shortened_article, flush=True)\n",
|
| 367 |
+
" output = copy.deepcopy(example)\n",
|
| 368 |
+
" output.update({'title': title, 'word_count': word_count, 'gist_word_count': gist_word_count, 'shortened_pages': shortened_pages, 'pages': pages})\n",
|
| 369 |
+
" if verbose:\n",
|
| 370 |
+
" print(f\"compression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})\")\n",
|
| 371 |
+
" return output\n",
|
| 372 |
+
"example_with_gists = quality_gisting(example, pages)"
|
| 373 |
+
],
|
| 374 |
+
"metadata": {
|
| 375 |
+
"id": "DLBolKnkS_9y"
|
| 376 |
+
},
|
| 377 |
+
"execution_count": null,
|
| 378 |
+
"outputs": []
|
| 379 |
+
},
|
| 380 |
+
{
|
| 381 |
+
"cell_type": "code",
|
| 382 |
+
"source": [
|
| 383 |
+
"#@title ReadAgent (3) Look-Up\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"prompt_lookup_template = \"\"\"\n",
|
| 386 |
+
"The following text is what you remembered from reading an article and a multiple choice question related to it.\n",
|
| 387 |
+
"You may read 1 to 6 page(s) of the article again to refresh your memory to prepare yourselve for the question.\n",
|
| 388 |
+
"Please respond with which page(s) you would like to read.\n",
|
| 389 |
+
"For example, if your only need to read Page 8, respond with \\\"I want to look up Page [8] to ...\\\";\n",
|
| 390 |
+
"if your would like to read Page 7 and 12, respond with \\\"I want to look up Page [7, 12] to ...\\\";\n",
|
| 391 |
+
"if your would like to read Page 2, 3, 7, 15 and 18, respond with \\\"I want to look up Page [2, 3, 7, 15, 18] to ...\\\".\n",
|
| 392 |
+
"if your would like to read Page 3, 4, 5, 12, 13 and 16, respond with \\\"I want to look up Page [3, 3, 4, 12, 13, 16] to ...\\\".\n",
|
| 393 |
+
"DO NOT select more pages if you don't need to.\n",
|
| 394 |
+
"DO NOT answer the question yet.\n",
|
| 395 |
+
"\n",
|
| 396 |
+
"Text:\n",
|
| 397 |
+
"{}\n",
|
| 398 |
+
"\n",
|
| 399 |
+
"Question:\n",
|
| 400 |
+
"{}\n",
|
| 401 |
+
"{}\n",
|
| 402 |
+
"\n",
|
| 403 |
+
"Take a deep breath and tell me: Which page(s) would you like to read again?\n",
|
| 404 |
+
"\"\"\"\n",
|
| 405 |
+
"\n",
|
| 406 |
+
"prompt_answer_template = \"\"\"\n",
|
| 407 |
+
"Read the following article and answer a multiple choice question.\n",
|
| 408 |
+
"For example, if (C) is correct, answer with \\\"Answer: (C) ...\\\"\n",
|
| 409 |
+
"\n",
|
| 410 |
+
"Article:\n",
|
| 411 |
+
"{}\n",
|
| 412 |
+
"\n",
|
| 413 |
+
"Question:\n",
|
| 414 |
+
"{}\n",
|
| 415 |
+
"{}\n",
|
| 416 |
+
"\n",
|
| 417 |
+
"\"\"\"\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"def quality_parallel_lookup(example, verbose=True):\n",
|
| 420 |
+
" preprocessed_pages = example['pages']\n",
|
| 421 |
+
" article = example['article']\n",
|
| 422 |
+
" title = example['title']\n",
|
| 423 |
+
" word_count = example['word_count']\n",
|
| 424 |
+
" gist_word_count = example['gist_word_count']\n",
|
| 425 |
+
" pages = example['pages']\n",
|
| 426 |
+
" shortened_pages = example['shortened_pages']\n",
|
| 427 |
+
" questions = example['questions']\n",
|
| 428 |
+
" options = example['options']\n",
|
| 429 |
+
" gold_labels = example['gold_labels'] # numerical [1, 2, 3, 4]\n",
|
| 430 |
+
"\n",
|
| 431 |
+
" print(f\"[Look-Up][Article {title}] {word_count} words\")\n",
|
| 432 |
+
"\n",
|
| 433 |
+
" model_choices = []\n",
|
| 434 |
+
" lookup_page_ids = []\n",
|
| 435 |
+
"\n",
|
| 436 |
+
" shortened_pages_pidx = []\n",
|
| 437 |
+
" for i, shortened_text in enumerate(shortened_pages):\n",
|
| 438 |
+
" shortened_pages_pidx.append(\"<Page {}>\\n\".format(i) + shortened_text)\n",
|
| 439 |
+
" shortened_article = '\\n'.join(shortened_pages_pidx)\n",
|
| 440 |
+
"\n",
|
| 441 |
+
" expanded_gist_word_counts = []\n",
|
| 442 |
+
" for i, label in enumerate(gold_labels):\n",
|
| 443 |
+
" # only test the first question for demo\n",
|
| 444 |
+
" if i != 1:\n",
|
| 445 |
+
" continue\n",
|
| 446 |
+
" q = questions[i]\n",
|
| 447 |
+
" print(\"question: \", q)\n",
|
| 448 |
+
" options_i = [f\"{ol} {o}\" for ol, o in zip(choices, options[i])]\n",
|
| 449 |
+
" print(\"options: \", \"\\n\".join(options_i))\n",
|
| 450 |
+
" prompt_lookup = prompt_lookup_template.format(shortened_article, q, '\\n'.join(options_i))\n",
|
| 451 |
+
"\n",
|
| 452 |
+
" page_ids = []\n",
|
| 453 |
+
"\n",
|
| 454 |
+
" response = query_model(prompt=prompt_lookup).strip()\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" try: start = response.index('[')\n",
|
| 457 |
+
" except ValueError: start = len(response)\n",
|
| 458 |
+
" try: end = response.index(']')\n",
|
| 459 |
+
" except ValueError: end = 0\n",
|
| 460 |
+
" if start < end:\n",
|
| 461 |
+
" page_ids_str = response[start+1:end].split(',')\n",
|
| 462 |
+
" page_ids = []\n",
|
| 463 |
+
" for p in page_ids_str:\n",
|
| 464 |
+
" if p.strip().isnumeric():\n",
|
| 465 |
+
" page_id = int(p)\n",
|
| 466 |
+
" if page_id < 0 or page_id >= len(pages):\n",
|
| 467 |
+
" print(\"Skip invalid page number: \", page_id, flush=True)\n",
|
| 468 |
+
" else:\n",
|
| 469 |
+
" page_ids.append(page_id)\n",
|
| 470 |
+
"\n",
|
| 471 |
+
" if verbose:\n",
|
| 472 |
+
" print(\"Model chose to look up page {}\".format(page_ids))\n",
|
| 473 |
+
"\n",
|
| 474 |
+
" # Memory expansion after look-up, replacing the target shortened page with the original page\n",
|
| 475 |
+
" expanded_shortened_pages = shortened_pages[:]\n",
|
| 476 |
+
" if len(page_ids) > 0:\n",
|
| 477 |
+
" for page_id in page_ids:\n",
|
| 478 |
+
" expanded_shortened_pages[page_id] = '\\n'.join(pages[page_id])\n",
|
| 479 |
+
"\n",
|
| 480 |
+
" expanded_shortened_article = '\\n'.join(expanded_shortened_pages)\n",
|
| 481 |
+
" expanded_gist_word_count = count_words(expanded_shortened_article)\n",
|
| 482 |
+
" if verbose:\n",
|
| 483 |
+
" print(\"Expanded shortened article:\\n\", expanded_shortened_article, flush=True)\n",
|
| 484 |
+
" prompt_answer = prompt_answer_template.format(expanded_shortened_article, q, '\\n'.join(options_i))\n",
|
| 485 |
+
"\n",
|
| 486 |
+
" # If the response doesn't follow the template, retry\n",
|
| 487 |
+
" model_choice = None\n",
|
| 488 |
+
" response = query_model(prompt=prompt_answer)\n",
|
| 489 |
+
" response = response.strip()\n",
|
| 490 |
+
" for j, choice in enumerate(choices):\n",
|
| 491 |
+
" if response.startswith(f\"Answer: {choice}\") or response.startswith(f\"Answer: {choice[1]}\"):\n",
|
| 492 |
+
" model_choice = j+1\n",
|
| 493 |
+
" break\n",
|
| 494 |
+
" is_correct = 1 if model_choice == label else 0\n",
|
| 495 |
+
" print(f\"question: {q}\")\n",
|
| 496 |
+
" print(f\"reference answer: {choices[label]}, model prediction: {choices[model_choice]}, is_correct: {is_correct}\")\n",
|
| 497 |
+
" print(f\"compression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})\")\n",
|
| 498 |
+
" print(f\"compression rate after look-up {round(100.0 - expanded_gist_word_count/word_count*100, 2)}% ({expanded_gist_word_count}/{word_count})\")\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"quality_parallel_lookup(example_with_gists)"
|
| 501 |
+
],
|
| 502 |
+
"metadata": {
|
| 503 |
+
"id": "8YKNTyDsXNIn"
|
| 504 |
+
},
|
| 505 |
+
"execution_count": null,
|
| 506 |
+
"outputs": []
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
"cell_type": "markdown",
|
| 510 |
+
"source": [
|
| 511 |
+
"#Prompts that we used in the paper\n",
|
| 512 |
+
"\n",
|
| 513 |
+
"In the following we show the prompts that were used for the QuALTIY, QMSum, NarrativeQA datasets with the PaLM 2-L model. While there are slight differences in prompt design, most of these are not due to optimizing prompts for specific datasets but rather a results of that each author wrote the prompts independently."
|
| 514 |
+
],
|
| 515 |
+
"metadata": {
|
| 516 |
+
"id": "Gn8fjomx7iRz"
|
| 517 |
+
}
|
| 518 |
+
},
|
| 519 |
+
{
|
| 520 |
+
"cell_type": "code",
|
| 521 |
+
"source": [
|
| 522 |
+
"# @title The prompts we used for QuALITY with PaLM 2-L\n",
|
| 523 |
+
"\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"# Pagination\n",
|
| 526 |
+
"pagination_prompt_template = \"\"\"\n",
|
| 527 |
+
"You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.\n",
|
| 528 |
+
"Numbered label are in angeled brackets. For example, if the label number is 19, it shows as <19> in text.\n",
|
| 529 |
+
"Please choose one label that it is natural to break reading.\n",
|
| 530 |
+
"Such point can be scene transition, end of a dialogue, end of an argument, narrative transition, etc.\n",
|
| 531 |
+
"Please answer the break point label and explain.\n",
|
| 532 |
+
"For example, if <57> is a good point to break, answer with \\\"Break point: <57>\\n Because ...\\\"\n",
|
| 533 |
+
"\n",
|
| 534 |
+
"Passage:\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"{passage_text}\n",
|
| 537 |
+
"{end_tag}\n",
|
| 538 |
+
"\n",
|
| 539 |
+
"\"\"\"\n",
|
| 540 |
+
"# passage_text: a chunk of text.\n",
|
| 541 |
+
"# end_tag: a string, whose value is \"\" if the text is at the end of the article, and otherwise \"\\n...\".\n",
|
| 542 |
+
"\n",
|
| 543 |
+
"\n",
|
| 544 |
+
"\n",
|
| 545 |
+
"# Gisting\n",
|
| 546 |
+
"gisting_prompt_template = \"\"\"\n",
|
| 547 |
+
"Please shorten the following passage.\n",
|
| 548 |
+
"Just give me a shortened version. DO NOT explain your reason.\n",
|
| 549 |
+
"\n",
|
| 550 |
+
"Passage:\n",
|
| 551 |
+
"{page_text}\n",
|
| 552 |
+
"\n",
|
| 553 |
+
"\"\"\"\n",
|
| 554 |
+
"# page_text: a page of text\n",
|
| 555 |
+
"\n",
|
| 556 |
+
"\n",
|
| 557 |
+
"\n",
|
| 558 |
+
"# Parallel Look-up (ReadAgent-P, up to 5 pages)\n",
|
| 559 |
+
"parallel_lookup_prompt_template = \"\"\"\n",
|
| 560 |
+
"The following text is what you remembered from reading an article and a multiple choice question related to it.\n",
|
| 561 |
+
"You may read 1 to 5 page(s) of the article again to refresh your memory to prepare yourselve for the question.\n",
|
| 562 |
+
"Please respond with which page(s) you would like to read again.\n",
|
| 563 |
+
"For example, if your would like to only read Page 8, respond with \\\"I want to look up Page [8] to ...\\\";\n",
|
| 564 |
+
"if your would like to read Page 7 and 12, respond with \\\"I want to look up Page [7, 12] to ...\\\";\n",
|
| 565 |
+
"if your would like to read Page 2, 3, 7, 15 and 18, respond with \\\"I want to look up Page [2, 3, 7, 15, 18] to ...\\\".\n",
|
| 566 |
+
"DO NOT select more pages if you don't need to.\n",
|
| 567 |
+
"DO NOT answer the question yet.\n",
|
| 568 |
+
"\n",
|
| 569 |
+
"Text:\n",
|
| 570 |
+
"{concatenated_gists}\n",
|
| 571 |
+
"\n",
|
| 572 |
+
"Question:\n",
|
| 573 |
+
"{question}\n",
|
| 574 |
+
"{options}\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"Take a deep breath and tell me: Which page(s) would you like to read again?\n",
|
| 577 |
+
"\"\"\"\n",
|
| 578 |
+
"# concatenated_gists: concatenated gists\n",
|
| 579 |
+
"# question: a question\n",
|
| 580 |
+
"# options: multiple-choice options\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"\n",
|
| 583 |
+
"\n",
|
| 584 |
+
"# Sequential Look-up (ReadAgent-S, up to 5 pages)\n",
|
| 585 |
+
"sequential_lookup_prompt_template = \"\"\"\n",
|
| 586 |
+
"The following text is what you remember from reading an article, followed by a question about the article.\n",
|
| 587 |
+
"You may read multiple pages of the article again to refresh your memory and prepare to answer the question.\n",
|
| 588 |
+
"Each page that you re-read can significantly improve your chance of answering the question correctly.\n",
|
| 589 |
+
"Please specify a SINGLE page you would like to read again or say \"STOP\".\n",
|
| 590 |
+
"To read a page again, respond with \"Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
|
| 591 |
+
"You can only specify a SINGLE page in your response at this time.\n",
|
| 592 |
+
"DO NOT select more pages if you don't need to.\n",
|
| 593 |
+
"To stop, simply say \"STOP\".\n",
|
| 594 |
+
"DO NOT answer the question in your response.\n",
|
| 595 |
+
"\n",
|
| 596 |
+
"Text:\n",
|
| 597 |
+
"{concatenated_gists}\n",
|
| 598 |
+
"End of text.\n",
|
| 599 |
+
"\n",
|
| 600 |
+
"Pages re-read already (DO NOT ask to read them again):\n",
|
| 601 |
+
"{past_page_numbers}\n",
|
| 602 |
+
"\n",
|
| 603 |
+
"Question:\n",
|
| 604 |
+
"{question}\n",
|
| 605 |
+
"{options}\n",
|
| 606 |
+
"\n",
|
| 607 |
+
"Specify a SINGLE page to read again, or say STOP:\n",
|
| 608 |
+
"\"\"\"\n",
|
| 609 |
+
"# concatenated_gists: concatenated gists\n",
|
| 610 |
+
"# past_page_numbers: page numbers that have already been retrieved\n",
|
| 611 |
+
"# question: a question\n",
|
| 612 |
+
"# options: options\n",
|
| 613 |
+
"\n",
|
| 614 |
+
"\n",
|
| 615 |
+
"\n",
|
| 616 |
+
"# Response/Answer\n",
|
| 617 |
+
"answer_prompt_template = \"\"\"\n",
|
| 618 |
+
"Read the following article and answer a multiple choice question.\n",
|
| 619 |
+
"For example, if (C) is correct, answer with \\\"Answer: (C) ...\\\"\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"Article:\n",
|
| 622 |
+
"{concatenated_pages_and_gists}\n",
|
| 623 |
+
"\n",
|
| 624 |
+
"Question:\n",
|
| 625 |
+
"{question}\n",
|
| 626 |
+
"{options}\n",
|
| 627 |
+
"\n",
|
| 628 |
+
"\"\"\"\n",
|
| 629 |
+
"# concatenated_pages_and_gists: concatenated raw pages and gists\n",
|
| 630 |
+
"# question: a question\n",
|
| 631 |
+
"# options: options"
|
| 632 |
+
],
|
| 633 |
+
"metadata": {
|
| 634 |
+
"id": "PGvYRIpO3J3Y"
|
| 635 |
+
},
|
| 636 |
+
"execution_count": null,
|
| 637 |
+
"outputs": []
|
| 638 |
+
},
|
| 639 |
+
{
|
| 640 |
+
"cell_type": "code",
|
| 641 |
+
"source": [
|
| 642 |
+
"# @title The prompts we used for QMSum with PaLM 2-L\n",
|
| 643 |
+
"\n",
|
| 644 |
+
"\n",
|
| 645 |
+
"# Pagination\n",
|
| 646 |
+
"pagination_prompt_template = \"\"\"\n",
|
| 647 |
+
"You are given a passage that is taken from a larger meeting transcript.\n",
|
| 648 |
+
"There are some numbered labels between the paragraphs (like <0>) in the passage.\n",
|
| 649 |
+
"Please choose one label at a natural transition in the passage.\n",
|
| 650 |
+
"For example, the label can be at the end of a dialogue, the end of an argument, a change in the topic being discussed, etc.\n",
|
| 651 |
+
"Please respond with the label and explain your choice.\n",
|
| 652 |
+
"For example, if <57> is a natural transition, answer with \"Label: <57>\\n Because ...\"\n",
|
| 653 |
+
"\n",
|
| 654 |
+
"Passage:\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"{preceding_text}\n",
|
| 657 |
+
"{passage_text}\n",
|
| 658 |
+
"{end_tag}\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"\"\"\"\n",
|
| 661 |
+
"# preceding_text: a fraction of previous context\n",
|
| 662 |
+
"# passage_text: a chunk of text.\n",
|
| 663 |
+
"# end_tag: a string, whose value is \"\" if the text is at the end of the article, and otherwise \"\\n...\".\n",
|
| 664 |
+
"\n",
|
| 665 |
+
"\n",
|
| 666 |
+
"\n",
|
| 667 |
+
"# Gisting\n",
|
| 668 |
+
"gisting_prompt_template = \"\"\"\n",
|
| 669 |
+
"Please shorten the following passage.\n",
|
| 670 |
+
"Just give a shortened version. DO NOT explain your reasoning.\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"Passage:\n",
|
| 673 |
+
"{page_text}\n",
|
| 674 |
+
"\n",
|
| 675 |
+
"\"\"\"\n",
|
| 676 |
+
"# page_text: a page of text\n",
|
| 677 |
+
"\n",
|
| 678 |
+
"\n",
|
| 679 |
+
"\n",
|
| 680 |
+
"# Parallel Look-up (ReadAgent-P, up to 2 pages)\n",
|
| 681 |
+
"parallel_lookup_prompt_template = \"\"\"\n",
|
| 682 |
+
"The following text is what you remember from reading a meeting transcript, followed by a question about the transcript.\n",
|
| 683 |
+
"You may read 1 or 2 pages of the transcript again to refresh your memory to prepare to answer the question.\n",
|
| 684 |
+
"Please respond with which page(s) you would like to read.\n",
|
| 685 |
+
"For example, if your would only like to read Page 8, respond with \"I want to look up Page [8] ...\"\n",
|
| 686 |
+
"If you would like to read Page 7 and 12, respond with \"I want to look up Page [7, 12] ...\".\n",
|
| 687 |
+
"Only select as many pages as you need, but no more than 2 pages.\n",
|
| 688 |
+
"Don't answer the question yet.\n",
|
| 689 |
+
"\n",
|
| 690 |
+
"Text:\n",
|
| 691 |
+
"{text}\n",
|
| 692 |
+
"End of text.\n",
|
| 693 |
+
"\n",
|
| 694 |
+
"Question:\n",
|
| 695 |
+
"{question}\n",
|
| 696 |
+
"\n",
|
| 697 |
+
"Which page(s) would you like to look up?\n",
|
| 698 |
+
"\"\"\"\n",
|
| 699 |
+
"# concatenated_gists: Concatenated gists\n",
|
| 700 |
+
"# question: a question\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"\n",
|
| 703 |
+
"\n",
|
| 704 |
+
"# Sequential Look-up (ReadAgent-S)\n",
|
| 705 |
+
"sequential_lookup_prompt_template = \"\"\"\n",
|
| 706 |
+
"The following text is what you remember from reading a meeting transcript, followed by a question about the transcript.\n",
|
| 707 |
+
"You may read multiple pages of the transcript again to refresh your memory and prepare to answer the question.\n",
|
| 708 |
+
"Each page that you re-read can significantly improve your chance of answering the question correctly.\n",
|
| 709 |
+
"Please specify a SINGLE page you would like to read again or say \"STOP\".\n",
|
| 710 |
+
"To read a page again, respond with \"Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
|
| 711 |
+
"You can only specify a SINGLE page in your response at this time.\n",
|
| 712 |
+
"DO NOT select more pages if you don't need to.\n",
|
| 713 |
+
"To stop, simply say \"STOP\".\n",
|
| 714 |
+
"DO NOT answer the question in your response.\n",
|
| 715 |
+
"\n",
|
| 716 |
+
"Text:\n",
|
| 717 |
+
"{concatenated_gists}\n",
|
| 718 |
+
"End of text.\n",
|
| 719 |
+
"\n",
|
| 720 |
+
"Pages re-read already (DO NOT ask to read them again):\n",
|
| 721 |
+
"{past_page_numbers}\n",
|
| 722 |
+
"\n",
|
| 723 |
+
"Question:\n",
|
| 724 |
+
"{question}\n",
|
| 725 |
+
"\n",
|
| 726 |
+
"Specify a SINGLE page to read again, or say STOP:\n",
|
| 727 |
+
"\"\"\"\n",
|
| 728 |
+
"# concatenated_gists: concatenated gists\n",
|
| 729 |
+
"# past_page_numbers: page numbers that have already been retrieved\n",
|
| 730 |
+
"# question: a question\n",
|
| 731 |
+
"\n",
|
| 732 |
+
"\n",
|
| 733 |
+
"\n",
|
| 734 |
+
"# Response/Answer\n",
|
| 735 |
+
"answer_prompt_template = \"\"\"\n",
|
| 736 |
+
"Read the question and text below and then answer the question.\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"Question:\n",
|
| 739 |
+
"{question}\n",
|
| 740 |
+
"\n",
|
| 741 |
+
"Text:\n",
|
| 742 |
+
"{concatenated_pages_and_gists}\n",
|
| 743 |
+
"End of Text.\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"Answer the question based on the above passage and retrieved pages. Your answer should be short and concise.\n",
|
| 746 |
+
"\"\"\"\n",
|
| 747 |
+
"# question: a question\n",
|
| 748 |
+
"# concatenated_pages_and_gists: concatenated raw pages and gists"
|
| 749 |
+
],
|
| 750 |
+
"metadata": {
|
| 751 |
+
"id": "Ya48p13EhhhI"
|
| 752 |
+
},
|
| 753 |
+
"execution_count": null,
|
| 754 |
+
"outputs": []
|
| 755 |
+
},
|
| 756 |
+
{
|
| 757 |
+
"cell_type": "code",
|
| 758 |
+
"source": [
|
| 759 |
+
"# @title The prompts we used for NarrativeQA - Gutenburg with PaLM 2-L\n",
|
| 760 |
+
"\n",
|
| 761 |
+
"\n",
|
| 762 |
+
"# Pagination\n",
|
| 763 |
+
"pagination_prompt_template = \"\"\"\n",
|
| 764 |
+
"You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.\n",
|
| 765 |
+
"Numbered label are in angeled brackets. For example, if the label number is 19, it shows as <19> in text.\n",
|
| 766 |
+
"Please choose one label that marks a major section break point.\n",
|
| 767 |
+
"Such points can be the beginning/end of a book, beginning/end of a chapter, end of a content table, a scene transition, end of a dialogue, etc.\n",
|
| 768 |
+
"If a point is chosen for the beginning of a book/chapter/etc and there is a title of the new book/chapter/etc, the break point must be chosen at a position right before the section number and title, not after.\n",
|
| 769 |
+
"\n",
|
| 770 |
+
"Please answer the break point label and explain.\n",
|
| 771 |
+
"For example, if <57> is a good point to break, answer with \\\"Breakpoint: <57> ...\\\"\n",
|
| 772 |
+
"\n",
|
| 773 |
+
"Text:\n",
|
| 774 |
+
"\n",
|
| 775 |
+
"{preceding_text}\n",
|
| 776 |
+
"{passage_text}\n",
|
| 777 |
+
"{end_tag}\n",
|
| 778 |
+
"\n",
|
| 779 |
+
"\"\"\"\n",
|
| 780 |
+
"# preceding_text: a fraction of previous context\n",
|
| 781 |
+
"# passage_text: a chunk of text.\n",
|
| 782 |
+
"# end_tag: a string, whose value is \"\" if the text is at the end of the article, and otherwise \"\\n...\".\n",
|
| 783 |
+
"\n",
|
| 784 |
+
"\n",
|
| 785 |
+
"\n",
|
| 786 |
+
"# Gisting\n",
|
| 787 |
+
"gisting_prompt_template = \"\"\"\n",
|
| 788 |
+
"Please shorten the following passage.\n",
|
| 789 |
+
"Just give me a shortened version. DO NOT explain your reason.\n",
|
| 790 |
+
"\n",
|
| 791 |
+
"Passage:\n",
|
| 792 |
+
"{page_text}\n",
|
| 793 |
+
"\n",
|
| 794 |
+
"\"\"\"\n",
|
| 795 |
+
"# page_text: a page of text\n",
|
| 796 |
+
"\n",
|
| 797 |
+
"\n",
|
| 798 |
+
"\n",
|
| 799 |
+
"# Parallel Look-up (ReadAgent-P, up to 2 pages)\n",
|
| 800 |
+
"parallel_lookup_prompt_template = \"\"\"\n",
|
| 801 |
+
"The following text is what you remembered from reading an article and a question related to it.\n",
|
| 802 |
+
"You may read 1 or 2 page(s) of the article again to refresh your memory to prepare yourselve for the question.\n",
|
| 803 |
+
"Please respond with which page(s) you would like to read in the order of importance, beginning with the most important page number.\n",
|
| 804 |
+
"For example, if your only need to read Page 8, respond with \\\"I want to look up Page [8] to ...\\\";\n",
|
| 805 |
+
"if your would like to read Page 12 and 7, respond with \\\"I want to look up Page [12, 7] to ...\\\";\n",
|
| 806 |
+
"DO NOT select more pages if you don't need to.\n",
|
| 807 |
+
"You don't need to answer the question yet.\n",
|
| 808 |
+
"\n",
|
| 809 |
+
"Text:\n",
|
| 810 |
+
"{concatenated_gists}\n",
|
| 811 |
+
"\n",
|
| 812 |
+
"Question:\n",
|
| 813 |
+
"{question}\n",
|
| 814 |
+
"\n",
|
| 815 |
+
"\"\"\"\n",
|
| 816 |
+
"# concatenated_gists: Concatenated gists\n",
|
| 817 |
+
"# question: a question\n",
|
| 818 |
+
"\n",
|
| 819 |
+
"\n",
|
| 820 |
+
"\n",
|
| 821 |
+
"# Sequential Look-up (ReadAgent-S)\n",
|
| 822 |
+
"sequential_lookup_prompt_template = \"\"\"\n",
|
| 823 |
+
"The following text is what you remember from reading a meeting transcript, followed by a question about the transcript.\n",
|
| 824 |
+
"You may read multiple pages of the transcript again to refresh your memory and prepare to answer the question.\n",
|
| 825 |
+
"Each page that you re-read can significantly improve your chance of answering the question correctly.\n",
|
| 826 |
+
"Please specify a SINGLE page you would like to read again or say \"STOP\".\n",
|
| 827 |
+
"To read a page again, respond with \"Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
|
| 828 |
+
"You can only specify a SINGLE page in your response at this time.\n",
|
| 829 |
+
"DO NOT select more pages if you don't need to.\n",
|
| 830 |
+
"To stop, simply say \"STOP\".\n",
|
| 831 |
+
"DO NOT answer the question in your response.\n",
|
| 832 |
+
"\n",
|
| 833 |
+
"Text:\n",
|
| 834 |
+
"{concatenated_gists}\n",
|
| 835 |
+
"End of text.\n",
|
| 836 |
+
"\n",
|
| 837 |
+
"Pages re-read already (DO NOT ask to read them again):\n",
|
| 838 |
+
"{past_page_numbers}\n",
|
| 839 |
+
"\n",
|
| 840 |
+
"Question:\n",
|
| 841 |
+
"{question}\n",
|
| 842 |
+
"\n",
|
| 843 |
+
"Specify a SINGLE page to read again, or say STOP:\n",
|
| 844 |
+
"\"\"\"\n",
|
| 845 |
+
"# concatenated_gists: concatenated gists\n",
|
| 846 |
+
"# past_page_numbers: page numbers that have already been retrieved\n",
|
| 847 |
+
"# question: a question\n",
|
| 848 |
+
"\n",
|
| 849 |
+
"\n",
|
| 850 |
+
"\n",
|
| 851 |
+
"# Response/Answer\n",
|
| 852 |
+
"answer_prompt_template = \"\"\"\n",
|
| 853 |
+
"{concatenated_pages_and_gists}\n",
|
| 854 |
+
"\n",
|
| 855 |
+
"Question:\n",
|
| 856 |
+
"{question}\n",
|
| 857 |
+
"\n",
|
| 858 |
+
"Answer the question based on the above passage and retrieved pages. Your answer should be short and concise.\n",
|
| 859 |
+
"\"\"\"\n",
|
| 860 |
+
"# concatenated_pages_and_gists: concatenated raw pages and gists\n",
|
| 861 |
+
"# question: a question"
|
| 862 |
+
],
|
| 863 |
+
"metadata": {
|
| 864 |
+
"id": "U210HZJuP4_u"
|
| 865 |
+
},
|
| 866 |
+
"execution_count": null,
|
| 867 |
+
"outputs": []
|
| 868 |
+
},
|
| 869 |
+
{
|
| 870 |
+
"cell_type": "code",
|
| 871 |
+
"source": [
|
| 872 |
+
"# @title The prompts we used for NarrativeQA - Movie Scripts with PaLM 2-L\n",
|
| 873 |
+
"\n",
|
| 874 |
+
"\n",
|
| 875 |
+
"# Pagination\n",
|
| 876 |
+
"pagination_prompt_template = \"\"\"\n",
|
| 877 |
+
"You are given a movie script and some numbered labels between the lines in the script.\n",
|
| 878 |
+
"Numbered label are in angeled brackets.\n",
|
| 879 |
+
"Please choose one label that it is natural to break reading. The label should be between <{start}> and <{end}>.\n",
|
| 880 |
+
"Such point can be scene transition, end of a dialogue, end of an argument, narrative transition, etc.\n",
|
| 881 |
+
"The answer should end with \"The break point is: <number>\", where the break point number is between angeled brackets.\n",
|
| 882 |
+
"\n",
|
| 883 |
+
"Script:\n",
|
| 884 |
+
"\n",
|
| 885 |
+
"{passage_text}\n",
|
| 886 |
+
"\"\"\"\n",
|
| 887 |
+
"# passage_text: a chunk of text.\n",
|
| 888 |
+
"\n",
|
| 889 |
+
"\n",
|
| 890 |
+
"\n",
|
| 891 |
+
"# Gisting\n",
|
| 892 |
+
"gisting_prompt_template = \"\"\"\n",
|
| 893 |
+
"Please shorten the following passage. The shortened passage should be in 128 tokens. Please refer to people with their full names whenever possible.\n",
|
| 894 |
+
"Just give me a shortened version. DO NOT explain your reason. If there is no meaning information in the passage, output \"I don't have enough information to shorten the passage.\"\n",
|
| 895 |
+
"\n",
|
| 896 |
+
"Passage:\n",
|
| 897 |
+
"{page_text}\n",
|
| 898 |
+
"\n",
|
| 899 |
+
"\"\"\"\n",
|
| 900 |
+
"# page_text: a page of text\n",
|
| 901 |
+
"\n",
|
| 902 |
+
"\n",
|
| 903 |
+
"\n",
|
| 904 |
+
"# Parallel Look-up (ReadAgent-P, up to 2 pages)\n",
|
| 905 |
+
"parallel_lookup_prompt_template = \"\"\"\n",
|
| 906 |
+
"The following text includes a summary of each page in a movie script, followed by a question about the script.\n",
|
| 907 |
+
"\n",
|
| 908 |
+
"Summary:\n",
|
| 909 |
+
"{concatenated_gists}\n",
|
| 910 |
+
"\n",
|
| 911 |
+
"Question:\n",
|
| 912 |
+
"{question}\n",
|
| 913 |
+
"\n",
|
| 914 |
+
"Based on the summary of each page, you may read the full details of 1 to 2 page(s) to obtain more information to answer the question.\n",
|
| 915 |
+
"Please respond with which page(s) you would like to read.\n",
|
| 916 |
+
"For example, if you only need to read Page X_0, the answer should end with \\\"I want to look up Page [X_0]\\\";\n",
|
| 917 |
+
"if you would like to read Page X_0 and X_1, the answer should end with \\\"I want to look up Page [X_0, X_1]\\\".\n",
|
| 918 |
+
"X_i above is a page index between 0 and {end}. DO NOT select more pages if you don't need to.\n",
|
| 919 |
+
"You don't need to answer the question yet.\n",
|
| 920 |
+
"\"\"\"\n",
|
| 921 |
+
"# concatenated_gists: Concatenated gists\n",
|
| 922 |
+
"# question: a question\n",
|
| 923 |
+
"\n",
|
| 924 |
+
"\n",
|
| 925 |
+
"\n",
|
| 926 |
+
"# Sequential Look-up (ReadAgent-S)\n",
|
| 927 |
+
"sequential_lookup_prompt_template = \"\"\"\n",
|
| 928 |
+
"The following text includes a summary of each page in a movie script, followed by a question about the script, and the previous answer based on the summary and already re-read pages.\n",
|
| 929 |
+
"\n",
|
| 930 |
+
"Summary:\n",
|
| 931 |
+
"{concatenated_gists}\n",
|
| 932 |
+
"\n",
|
| 933 |
+
"Pages re-read already (DO NOT ask to read them again):\n",
|
| 934 |
+
"{past_page_numbers}\n",
|
| 935 |
+
"\n",
|
| 936 |
+
"Question:\n",
|
| 937 |
+
"{question}\n",
|
| 938 |
+
"\n",
|
| 939 |
+
"Previous Answer:\n",
|
| 940 |
+
"{previous_answer}\n",
|
| 941 |
+
"\n",
|
| 942 |
+
"Based on the summary of each page, you may read the full details of multiple pages to obtain more information to answer the question.\n",
|
| 943 |
+
"To read a page again, respond with \"I want to look up Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
|
| 944 |
+
"PAGE_NUM is a page index between 0 and {end}, excluding {pages_reread}.\n",
|
| 945 |
+
"You can only specify a SINGLE page in your response at this time. The page should not be in the re-read pages.\n",
|
| 946 |
+
"To stop, simply say \"STOP\".\n",
|
| 947 |
+
"DO NOT answer the question in your response.\n",
|
| 948 |
+
"You don't need to answer the question yet.\n",
|
| 949 |
+
"\"\"\"\n",
|
| 950 |
+
"# concatenated_gists: concatenated gists\n",
|
| 951 |
+
"# past_page_numbers: (only after the first query) page numbers that have already been retrieved\n",
|
| 952 |
+
"# question: a question\n",
|
| 953 |
+
"# previous_answer: the previous answer given by the model with gists and previously retrieved raw pages\n",
|
| 954 |
+
"\n",
|
| 955 |
+
"\n",
|
| 956 |
+
"\n",
|
| 957 |
+
"# Response/Answer\n",
|
| 958 |
+
"answer_prompt_template = \"\"\"\n",
|
| 959 |
+
"{concatenated_pages_and_gists}\n",
|
| 960 |
+
"\n",
|
| 961 |
+
"Question:\n",
|
| 962 |
+
"{question}\n",
|
| 963 |
+
"\n",
|
| 964 |
+
"Answer the question based on the above passage and retrieved pages. Your answer should be short and concise.\n",
|
| 965 |
+
"\"\"\"\n",
|
| 966 |
+
"# concatenated_pages_and_gists: concatenated raw pages and gists\n",
|
| 967 |
+
"# question: a question"
|
| 968 |
+
],
|
| 969 |
+
"metadata": {
|
| 970 |
+
"id": "K1pw1NPasHPC"
|
| 971 |
+
},
|
| 972 |
+
"execution_count": null,
|
| 973 |
+
"outputs": []
|
| 974 |
+
}
|
| 975 |
+
]
|
| 976 |
+
}
|
baselines/read-agent/run_readagent_baseline.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ReadAgent baseline for the EvolV-Mem benchmark.
|
| 3 |
+
|
| 4 |
+
Adapts ReadAgent's 3-step pipeline for EvolV-Mem:
|
| 5 |
+
1. Pagination: each session = one page (natural segmentation, no LLM needed)
|
| 6 |
+
2. Gisting: use pre-computed session summaries (no LLM needed)
|
| 7 |
+
3. Look-up + Answer: LLM reads gists, selects pages to expand, then answers
|
| 8 |
+
|
| 9 |
+
Since ~1000 sessions don't fit in one prompt as gists, we pre-filter with SBert
|
| 10 |
+
to top-N sessions, then apply ReadAgent's look-up on those.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python baselines/read-agent.github.io/run_readagent_baseline.py \
|
| 14 |
+
--in_file dataset/evolv_mem_v4.json \
|
| 15 |
+
--out_file output/readagent_qwen30b_v4.jsonl \
|
| 16 |
+
--summary_file dataset/all_session_summary.json \
|
| 17 |
+
--sessions_file dataset/all_sessions.json \
|
| 18 |
+
--profile_file metadata/generated_user_profile.json
|
| 19 |
+
|
| 20 |
+
Env vars:
|
| 21 |
+
VLLM_BASE_URL (default http://localhost:8000/v1)
|
| 22 |
+
VLLM_API_KEY (default EMPTY)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import logging
|
| 28 |
+
import os
|
| 29 |
+
import re
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
from collections import defaultdict
|
| 33 |
+
from typing import Dict, List, Optional
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
|
| 38 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# vLLM LLM helper
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
MODEL_NAME = os.getenv("VLLM_MODEL_NAME", "Qwen/Qwen3-30B-A3B-Instruct-2507")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_llm_client():
|
| 48 |
+
from openai import OpenAI
|
| 49 |
+
return OpenAI(
|
| 50 |
+
base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
|
| 51 |
+
api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def llm_call(client, prompt: str, max_tokens: int = 4096, temperature: float = 0.0) -> str:
|
| 56 |
+
for attempt in range(6):
|
| 57 |
+
try:
|
| 58 |
+
response = client.chat.completions.create(
|
| 59 |
+
model=MODEL_NAME,
|
| 60 |
+
messages=[{"role": "user", "content": prompt}],
|
| 61 |
+
max_tokens=max_tokens,
|
| 62 |
+
temperature=temperature,
|
| 63 |
+
)
|
| 64 |
+
content = response.choices[0].message.content if response.choices else None
|
| 65 |
+
if content is None:
|
| 66 |
+
# NVIDIA endpoint occasionally returns null content; retry once with same prompt
|
| 67 |
+
wait = min(2 ** attempt * 2, 30)
|
| 68 |
+
print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
|
| 69 |
+
time.sleep(wait)
|
| 70 |
+
continue
|
| 71 |
+
return content.strip()
|
| 72 |
+
except Exception as e:
|
| 73 |
+
msg = str(e).lower()
|
| 74 |
+
if any(code in msg for code in ("429", "500", "503", "rate limit")):
|
| 75 |
+
wait = min(2 ** attempt * 5, 60)
|
| 76 |
+
print(f"[WARN] LLM retry {attempt+1}/6, sleeping {wait}s: {e}")
|
| 77 |
+
time.sleep(wait)
|
| 78 |
+
continue
|
| 79 |
+
print(f"[ERROR] LLM call failed: {e}")
|
| 80 |
+
raise
|
| 81 |
+
raise RuntimeError("LLM call failed after 6 retries")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# ReadAgent prompts (adapted from the notebook for chat-history QA)
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
PROMPT_LOOKUP = """The following text contains gists (short summaries) of pages from a user's chat history. Each page is one conversation session.
|
| 89 |
+
You are also given a question about the user's chat history.
|
| 90 |
+
|
| 91 |
+
You may select up to {max_pages} page(s) to read in full to help answer the question.
|
| 92 |
+
Please respond with which page(s) you would like to read.
|
| 93 |
+
For example, if you only need to read Page 3, respond with "I want to look up Page [3] to ...";
|
| 94 |
+
if you would like to read Page 2 and 7, respond with "I want to look up Page [2, 7] to ...".
|
| 95 |
+
DO NOT select more pages than necessary.
|
| 96 |
+
DO NOT answer the question yet.
|
| 97 |
+
|
| 98 |
+
Text:
|
| 99 |
+
{concatenated_gists}
|
| 100 |
+
|
| 101 |
+
Question:
|
| 102 |
+
{question}
|
| 103 |
+
|
| 104 |
+
Current Date: {question_date}
|
| 105 |
+
|
| 106 |
+
Take a deep breath and tell me: Which page(s) would you like to read again?
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
PROMPT_ANSWER = """Read the following chat history and answer the question.
|
| 110 |
+
|
| 111 |
+
{profile_section}
|
| 112 |
+
|
| 113 |
+
Chat History:
|
| 114 |
+
{concatenated_pages_and_gists}
|
| 115 |
+
|
| 116 |
+
Current Date: {question_date}
|
| 117 |
+
Question: {question}
|
| 118 |
+
Answer:"""
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
# Embedding pre-filter
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
|
| 125 |
+
def embed_and_filter(
|
| 126 |
+
question: str,
|
| 127 |
+
session_ids: List[str],
|
| 128 |
+
session_gists: List[str],
|
| 129 |
+
embedding_model,
|
| 130 |
+
top_k: int = 100,
|
| 131 |
+
) -> List[int]:
|
| 132 |
+
"""Return indices of top-k most relevant sessions by cosine similarity."""
|
| 133 |
+
if len(session_ids) <= top_k:
|
| 134 |
+
return list(range(len(session_ids)))
|
| 135 |
+
|
| 136 |
+
question_emb = embedding_model.encode(question)
|
| 137 |
+
gist_embs = embedding_model.encode(session_gists)
|
| 138 |
+
|
| 139 |
+
question_norm = question_emb / (np.linalg.norm(question_emb) + 1e-10)
|
| 140 |
+
gist_norms = gist_embs / (np.linalg.norm(gist_embs, axis=1, keepdims=True) + 1e-10)
|
| 141 |
+
similarities = gist_norms @ question_norm
|
| 142 |
+
|
| 143 |
+
top_indices = np.argsort(similarities)[::-1][:top_k].tolist()
|
| 144 |
+
# Sort by original order (chronological) for coherent reading
|
| 145 |
+
top_indices.sort()
|
| 146 |
+
return top_indices
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# ReadAgent Look-up: parse page selection from LLM response
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
def parse_lookup_pages(response: str, max_page: int) -> List[int]:
|
| 154 |
+
"""Parse page indices from ReadAgent look-up response like 'Page [2, 7, 12]'."""
|
| 155 |
+
try:
|
| 156 |
+
start = response.index('[')
|
| 157 |
+
end = response.index(']')
|
| 158 |
+
except ValueError:
|
| 159 |
+
return []
|
| 160 |
+
|
| 161 |
+
page_ids = []
|
| 162 |
+
for p in response[start + 1:end].split(','):
|
| 163 |
+
p = p.strip()
|
| 164 |
+
if p.isnumeric():
|
| 165 |
+
pid = int(p)
|
| 166 |
+
if 0 <= pid < max_page:
|
| 167 |
+
page_ids.append(pid)
|
| 168 |
+
return page_ids
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ---------------------------------------------------------------------------
|
| 172 |
+
# Retrieval metrics
|
| 173 |
+
# ---------------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
def evaluate_retrieval(recalled_docs, correct_docs):
|
| 176 |
+
recall_any = float(any(doc in recalled_docs for doc in correct_docs))
|
| 177 |
+
recall_all = float(all(doc in recalled_docs for doc in correct_docs))
|
| 178 |
+
return recall_any, recall_all
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def print_average_metrics(retrieval_metric_list):
|
| 182 |
+
metric_sums = defaultdict(float)
|
| 183 |
+
metric_counts = defaultdict(int)
|
| 184 |
+
for metric in retrieval_metric_list:
|
| 185 |
+
for k, v in metric.items():
|
| 186 |
+
metric_sums[k] += v
|
| 187 |
+
metric_counts[k] += 1
|
| 188 |
+
print(" Average retrieval metrics:")
|
| 189 |
+
for k in sorted(metric_sums):
|
| 190 |
+
avg = metric_sums[k] / metric_counts[k]
|
| 191 |
+
print(f" {k}: {avg:.4f}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ---------------------------------------------------------------------------
|
| 195 |
+
# Main
|
| 196 |
+
# ---------------------------------------------------------------------------
|
| 197 |
+
|
| 198 |
+
def main():
|
| 199 |
+
parser = argparse.ArgumentParser(description="ReadAgent baseline for EvolV-Mem")
|
| 200 |
+
parser.add_argument("--in_file", type=str, required=True,
|
| 201 |
+
help="Path to evolv_mem_v4.json")
|
| 202 |
+
parser.add_argument("--out_file", type=str, required=True,
|
| 203 |
+
help="Output JSONL file")
|
| 204 |
+
parser.add_argument("--summary_file", type=str, required=True,
|
| 205 |
+
help="Path to all_session_summary.json")
|
| 206 |
+
parser.add_argument("--sessions_file", type=str, required=True,
|
| 207 |
+
help="Path to all_sessions.json")
|
| 208 |
+
parser.add_argument("--profile_file", type=str, default=None,
|
| 209 |
+
help="Path to generated_user_profile.json")
|
| 210 |
+
parser.add_argument("--embedding_model", type=str,
|
| 211 |
+
default="sentence-transformers/multi-qa-mpnet-base-cos-v1",
|
| 212 |
+
help="SentenceTransformer model for pre-filtering")
|
| 213 |
+
# ReadAgent params
|
| 214 |
+
parser.add_argument("--gist_top_k", type=int, default=100,
|
| 215 |
+
help="Number of sessions to keep after embedding pre-filter (default 100)")
|
| 216 |
+
parser.add_argument("--max_lookup_pages", type=int, default=10,
|
| 217 |
+
help="Max pages the LLM can select for full reading (default 10)")
|
| 218 |
+
# Limit
|
| 219 |
+
parser.add_argument("--limit", type=int, default=None,
|
| 220 |
+
help="Process only the first N questions")
|
| 221 |
+
args = parser.parse_args()
|
| 222 |
+
|
| 223 |
+
# -----------------------------------------------------------------------
|
| 224 |
+
# Load data
|
| 225 |
+
# -----------------------------------------------------------------------
|
| 226 |
+
print(f"Loading benchmark from {args.in_file} ...")
|
| 227 |
+
with open(args.in_file) as f:
|
| 228 |
+
benchmark = json.load(f)
|
| 229 |
+
if args.limit:
|
| 230 |
+
benchmark = benchmark[:args.limit]
|
| 231 |
+
print(f" {len(benchmark)} questions loaded.")
|
| 232 |
+
|
| 233 |
+
print(f"Loading session summaries from {args.summary_file} ...")
|
| 234 |
+
with open(args.summary_file) as f:
|
| 235 |
+
summaries = json.load(f)
|
| 236 |
+
print(f" {len(summaries)} session summaries loaded.")
|
| 237 |
+
|
| 238 |
+
print(f"Loading sessions from {args.sessions_file} ...")
|
| 239 |
+
with open(args.sessions_file) as f:
|
| 240 |
+
all_sessions = json.load(f)
|
| 241 |
+
print(f" {len(all_sessions)} sessions loaded.")
|
| 242 |
+
|
| 243 |
+
profiles = {}
|
| 244 |
+
if args.profile_file and os.path.exists(args.profile_file):
|
| 245 |
+
print(f"Loading user profiles from {args.profile_file} ...")
|
| 246 |
+
with open(args.profile_file) as f:
|
| 247 |
+
profiles = json.load(f)
|
| 248 |
+
print(f" {len(profiles)} profiles loaded.")
|
| 249 |
+
|
| 250 |
+
# -----------------------------------------------------------------------
|
| 251 |
+
# Resume support
|
| 252 |
+
# -----------------------------------------------------------------------
|
| 253 |
+
existing_qids = set()
|
| 254 |
+
if os.path.exists(args.out_file):
|
| 255 |
+
with open(args.out_file) as f:
|
| 256 |
+
for line in f:
|
| 257 |
+
line = line.strip()
|
| 258 |
+
if line:
|
| 259 |
+
existing_qids.add(json.loads(line)["question_id"])
|
| 260 |
+
print(f" Resuming: {len(existing_qids)} questions already processed.")
|
| 261 |
+
|
| 262 |
+
# -----------------------------------------------------------------------
|
| 263 |
+
# Initialize models
|
| 264 |
+
# -----------------------------------------------------------------------
|
| 265 |
+
print("Initializing embedding model ...")
|
| 266 |
+
from sentence_transformers import SentenceTransformer
|
| 267 |
+
embedding_model = SentenceTransformer(args.embedding_model)
|
| 268 |
+
|
| 269 |
+
print("Initializing vLLM client ...")
|
| 270 |
+
client = get_llm_client()
|
| 271 |
+
|
| 272 |
+
# -----------------------------------------------------------------------
|
| 273 |
+
# Process questions
|
| 274 |
+
# -----------------------------------------------------------------------
|
| 275 |
+
retrieval_metric_list = []
|
| 276 |
+
out_f = open(args.out_file, "a")
|
| 277 |
+
|
| 278 |
+
for di, entry in enumerate(tqdm(benchmark, desc="ReadAgent baseline")):
|
| 279 |
+
qid = entry["question_id"]
|
| 280 |
+
question = entry["question"]
|
| 281 |
+
question_date = entry["question_date"]
|
| 282 |
+
|
| 283 |
+
if qid in existing_qids:
|
| 284 |
+
continue
|
| 285 |
+
|
| 286 |
+
try:
|
| 287 |
+
haystack_sids = entry["haystack_session_ids"]
|
| 288 |
+
haystack_dates = entry["haystack_dates"]
|
| 289 |
+
|
| 290 |
+
# === Step 1 & 2: Pagination + Gisting (free with cached data) ===
|
| 291 |
+
# Each session = one page; session summary = gist
|
| 292 |
+
page_sids = []
|
| 293 |
+
page_dates = []
|
| 294 |
+
page_gists = []
|
| 295 |
+
for sid, date_str in zip(haystack_sids, haystack_dates):
|
| 296 |
+
summary_data = summaries.get(sid)
|
| 297 |
+
if summary_data is None:
|
| 298 |
+
continue
|
| 299 |
+
text = summary_data.get("session_summary", "")
|
| 300 |
+
if not text:
|
| 301 |
+
turn_sums = summary_data.get("turn_summaries", [])
|
| 302 |
+
text = " ".join(turn_sums) if turn_sums else ""
|
| 303 |
+
if not text:
|
| 304 |
+
continue
|
| 305 |
+
page_sids.append(sid)
|
| 306 |
+
page_dates.append(date_str)
|
| 307 |
+
page_gists.append(text)
|
| 308 |
+
|
| 309 |
+
if not page_gists:
|
| 310 |
+
result = {
|
| 311 |
+
"q_idx": di, "question_id": qid,
|
| 312 |
+
"hypothesis": "Insufficient information to answer.",
|
| 313 |
+
"n_pages": 0,
|
| 314 |
+
}
|
| 315 |
+
print(json.dumps(result), file=out_f, flush=True)
|
| 316 |
+
continue
|
| 317 |
+
|
| 318 |
+
# === Embedding pre-filter to top-K gists ===
|
| 319 |
+
filtered_indices = embed_and_filter(
|
| 320 |
+
question, page_sids, page_gists, embedding_model, top_k=args.gist_top_k
|
| 321 |
+
)
|
| 322 |
+
filtered_sids = [page_sids[i] for i in filtered_indices]
|
| 323 |
+
filtered_dates = [page_dates[i] for i in filtered_indices]
|
| 324 |
+
filtered_gists = [page_gists[i] for i in filtered_indices]
|
| 325 |
+
|
| 326 |
+
# Build gist text with page numbers
|
| 327 |
+
gist_lines = []
|
| 328 |
+
for local_idx, (sid, date, gist) in enumerate(
|
| 329 |
+
zip(filtered_sids, filtered_dates, filtered_gists)
|
| 330 |
+
):
|
| 331 |
+
gist_lines.append(f"<Page {local_idx}> [Date: {date}]\n{gist}")
|
| 332 |
+
concatenated_gists = "\n\n".join(gist_lines)
|
| 333 |
+
|
| 334 |
+
# === Step 3a: ReadAgent Look-Up ===
|
| 335 |
+
lookup_prompt = PROMPT_LOOKUP.format(
|
| 336 |
+
max_pages=args.max_lookup_pages,
|
| 337 |
+
concatenated_gists=concatenated_gists,
|
| 338 |
+
question=question,
|
| 339 |
+
question_date=question_date,
|
| 340 |
+
)
|
| 341 |
+
lookup_response = llm_call(client, lookup_prompt, max_tokens=4096, temperature=0.0)
|
| 342 |
+
selected_page_ids = parse_lookup_pages(lookup_response, len(filtered_sids))
|
| 343 |
+
|
| 344 |
+
print(f" [{di}] Look-up selected pages: {selected_page_ids} "
|
| 345 |
+
f"(out of {len(filtered_sids)} gists)")
|
| 346 |
+
|
| 347 |
+
# === Step 3b: Expand selected pages with full session content ===
|
| 348 |
+
expanded_lines = []
|
| 349 |
+
retrieved_session_ids = []
|
| 350 |
+
for local_idx, (sid, date, gist) in enumerate(
|
| 351 |
+
zip(filtered_sids, filtered_dates, filtered_gists)
|
| 352 |
+
):
|
| 353 |
+
if local_idx in selected_page_ids:
|
| 354 |
+
# Expand: use full session content
|
| 355 |
+
session_turns = all_sessions.get(sid, [])
|
| 356 |
+
full_text = "\n".join(
|
| 357 |
+
f"{t.get('role', 'user')}: {t.get('content', '')}"
|
| 358 |
+
for t in session_turns
|
| 359 |
+
)
|
| 360 |
+
expanded_lines.append(
|
| 361 |
+
f"<Page {local_idx}> [Session {sid} | Date: {date}] [FULL]\n{full_text}"
|
| 362 |
+
)
|
| 363 |
+
retrieved_session_ids.append(sid)
|
| 364 |
+
else:
|
| 365 |
+
# Keep gist
|
| 366 |
+
expanded_lines.append(
|
| 367 |
+
f"<Page {local_idx}> [Date: {date}] [GIST]\n{gist}"
|
| 368 |
+
)
|
| 369 |
+
concatenated_pages_and_gists = "\n\n".join(expanded_lines)
|
| 370 |
+
|
| 371 |
+
# === Step 3c: Answer ===
|
| 372 |
+
user_id = qid.split("_q_")[0] if "_q_" in qid else qid
|
| 373 |
+
user_profile = profiles.get(user_id, None)
|
| 374 |
+
profile_section = f"User Profile:\n{user_profile}" if user_profile else ""
|
| 375 |
+
|
| 376 |
+
answer_prompt = PROMPT_ANSWER.format(
|
| 377 |
+
profile_section=profile_section,
|
| 378 |
+
concatenated_pages_and_gists=concatenated_pages_and_gists,
|
| 379 |
+
question=question,
|
| 380 |
+
question_date=question_date,
|
| 381 |
+
)
|
| 382 |
+
answer = llm_call(client, answer_prompt, max_tokens=8192, temperature=0.0)
|
| 383 |
+
|
| 384 |
+
# === Retrieval metrics ===
|
| 385 |
+
answer_session_ids = entry.get("answer_session_ids", [])
|
| 386 |
+
retrieval_metric = {}
|
| 387 |
+
if answer_session_ids and retrieved_session_ids:
|
| 388 |
+
for topk in [5, 10, 20, 30]:
|
| 389 |
+
r_any, r_all = evaluate_retrieval(
|
| 390 |
+
retrieved_session_ids[:topk], answer_session_ids
|
| 391 |
+
)
|
| 392 |
+
retrieval_metric[f"recall_any@{topk}"] = r_any
|
| 393 |
+
retrieval_metric[f"recall_all@{topk}"] = r_all
|
| 394 |
+
retrieval_metric_list.append(retrieval_metric)
|
| 395 |
+
print_average_metrics(retrieval_metric_list)
|
| 396 |
+
|
| 397 |
+
# === Output ===
|
| 398 |
+
result = {
|
| 399 |
+
"q_idx": di,
|
| 400 |
+
"question_id": qid,
|
| 401 |
+
"hypothesis": answer,
|
| 402 |
+
"n_pages_total": len(page_gists),
|
| 403 |
+
"n_pages_filtered": len(filtered_sids),
|
| 404 |
+
"n_pages_expanded": len(selected_page_ids),
|
| 405 |
+
"retrieved_session_ids": retrieved_session_ids,
|
| 406 |
+
"retrieval_metric": retrieval_metric,
|
| 407 |
+
}
|
| 408 |
+
print(json.dumps(result), file=out_f, flush=True)
|
| 409 |
+
|
| 410 |
+
print(f" [{di}] Q: {question[:100]}...")
|
| 411 |
+
print(f" [{di}] A: {answer[:200]}...")
|
| 412 |
+
|
| 413 |
+
except Exception as e:
|
| 414 |
+
print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
|
| 415 |
+
import traceback
|
| 416 |
+
traceback.print_exc()
|
| 417 |
+
continue
|
| 418 |
+
|
| 419 |
+
out_f.close()
|
| 420 |
+
print(f"\nDone. Results saved to {args.out_file}")
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
if __name__ == "__main__":
|
| 424 |
+
main()
|
evaluate_qa.py
ADDED
|
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from openai import OpenAI
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from openai import AzureOpenAI
|
| 13 |
+
from azure.identity import (
|
| 14 |
+
AzureCliCredential,
|
| 15 |
+
ChainedTokenCredential,
|
| 16 |
+
ManagedIdentityCredential,
|
| 17 |
+
get_bearer_token_provider,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
|
| 21 |
+
if AZURE_OAUTH_SCOPE:
|
| 22 |
+
credential = get_bearer_token_provider(
|
| 23 |
+
ChainedTokenCredential(
|
| 24 |
+
AzureCliCredential(),
|
| 25 |
+
ManagedIdentityCredential(),
|
| 26 |
+
),
|
| 27 |
+
AZURE_OAUTH_SCOPE,
|
| 28 |
+
)
|
| 29 |
+
else:
|
| 30 |
+
credential = None
|
| 31 |
+
except ImportError:
|
| 32 |
+
AzureOpenAI = None
|
| 33 |
+
credential = None
|
| 34 |
+
|
| 35 |
+
from model_zoo import model_zoo
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL).
|
| 39 |
+
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
|
| 40 |
+
# OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy).
|
| 41 |
+
TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "")
|
| 42 |
+
|
| 43 |
+
ATOMIC_PROMPT_VERSION = "atomic-v1"
|
| 44 |
+
LEGACY_PROMPT_VERSION = "binary-v0"
|
| 45 |
+
ATOM_SCORES = {
|
| 46 |
+
"correct": 1.0,
|
| 47 |
+
"partially_correct": 0.5,
|
| 48 |
+
"missing": 0.0,
|
| 49 |
+
"incorrect": 0.0,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _retryable_status(e) -> "int | None":
|
| 54 |
+
status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
|
| 55 |
+
if status in (429, 500, 503, 403):
|
| 56 |
+
return status
|
| 57 |
+
resp = getattr(e, "response", None)
|
| 58 |
+
if resp is not None and getattr(resp, "status_code", None) in (429, 500, 503):
|
| 59 |
+
return resp.status_code
|
| 60 |
+
msg = str(e).lower()
|
| 61 |
+
if "429" in msg or "rate limit" in msg:
|
| 62 |
+
return 429
|
| 63 |
+
if "500" in msg or "internal server error" in msg:
|
| 64 |
+
return 500
|
| 65 |
+
if "503" in msg or "api configuration unavailable" in msg:
|
| 66 |
+
return 503
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def parse_json_object(text):
|
| 71 |
+
text = (text or "").strip()
|
| 72 |
+
if text.startswith("```"):
|
| 73 |
+
text = re.sub(r"^```(?:json)?", "", text).strip()
|
| 74 |
+
text = re.sub(r"```$", "", text).strip()
|
| 75 |
+
try:
|
| 76 |
+
return json.loads(text)
|
| 77 |
+
except json.JSONDecodeError:
|
| 78 |
+
start = text.find("{")
|
| 79 |
+
end = text.rfind("}") + 1
|
| 80 |
+
if start >= 0 and end > start:
|
| 81 |
+
return json.loads(text[start:end])
|
| 82 |
+
raise
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def sanitize_model_name(name):
|
| 86 |
+
return re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def read_json_or_jsonl(path):
|
| 90 |
+
try:
|
| 91 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 92 |
+
return json.load(f)
|
| 93 |
+
except json.JSONDecodeError:
|
| 94 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 95 |
+
return [json.loads(line) for line in f if line.strip()]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def read_existing_jsonl(path):
|
| 99 |
+
if not path or not os.path.exists(path):
|
| 100 |
+
return {}
|
| 101 |
+
rows = {}
|
| 102 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 103 |
+
for line in f:
|
| 104 |
+
if not line.strip():
|
| 105 |
+
continue
|
| 106 |
+
obj = json.loads(line)
|
| 107 |
+
if "question_id" in obj:
|
| 108 |
+
rows[obj["question_id"]] = obj
|
| 109 |
+
return rows
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def write_json(path, obj):
|
| 113 |
+
tmp_path = path + ".tmp"
|
| 114 |
+
with open(tmp_path, "w", encoding="utf-8") as f:
|
| 115 |
+
json.dump(obj, f, ensure_ascii=False, indent=2)
|
| 116 |
+
f.write("\n")
|
| 117 |
+
os.replace(tmp_path, path)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def append_jsonl(path, row):
|
| 121 |
+
with open(path, "a", encoding="utf-8") as f:
|
| 122 |
+
print(json.dumps(row, ensure_ascii=False), file=f, flush=True)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def default_result_file(hyp_file, metric_model, eval_mode):
|
| 126 |
+
if eval_mode == "legacy":
|
| 127 |
+
return f"{hyp_file}.eval-results-{metric_model}"
|
| 128 |
+
model_tag = sanitize_model_name(metric_model)
|
| 129 |
+
return f"{hyp_file}.eval-results-{model_tag}-{ATOMIC_PROMPT_VERSION}-{eval_mode}"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def default_rubric_file(ref_file):
|
| 133 |
+
return f"{ref_file}.{ATOMIC_PROMPT_VERSION}.rubric.json"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_rubric_file(path, ref_file):
|
| 137 |
+
if not path or not os.path.exists(path):
|
| 138 |
+
return {
|
| 139 |
+
"prompt_version": ATOMIC_PROMPT_VERSION,
|
| 140 |
+
"source_ref_file": ref_file,
|
| 141 |
+
"rubrics": {},
|
| 142 |
+
}
|
| 143 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 144 |
+
data = json.load(f)
|
| 145 |
+
if "rubrics" in data:
|
| 146 |
+
data.setdefault("prompt_version", ATOMIC_PROMPT_VERSION)
|
| 147 |
+
data.setdefault("source_ref_file", ref_file)
|
| 148 |
+
return data
|
| 149 |
+
return {
|
| 150 |
+
"prompt_version": ATOMIC_PROMPT_VERSION,
|
| 151 |
+
"source_ref_file": ref_file,
|
| 152 |
+
"rubrics": data,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def question_type_guidance(task):
|
| 157 |
+
if task in ["Information Absence"]:
|
| 158 |
+
return (
|
| 159 |
+
"The correct answer is that the information is unavailable, absent, "
|
| 160 |
+
"not yet known, or not supported. A response that gives a concrete "
|
| 161 |
+
"answer instead of abstaining is incorrect."
|
| 162 |
+
)
|
| 163 |
+
if task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
|
| 164 |
+
return (
|
| 165 |
+
"Check every requested item, count, list member, and named fact. "
|
| 166 |
+
"Exact counts and required list coverage matter. Extra material "
|
| 167 |
+
"facts that change the answer should be flagged."
|
| 168 |
+
)
|
| 169 |
+
if task in ["Aggregation + Temporal"]:
|
| 170 |
+
return (
|
| 171 |
+
"Check both the aggregated facts and their time/order associations. "
|
| 172 |
+
"An answer can name the right item but still be wrong if the timing, "
|
| 173 |
+
"ordering, before/after relation, or year is wrong."
|
| 174 |
+
)
|
| 175 |
+
if task in ["Temporal Reasoning", "temporal-reasoning"]:
|
| 176 |
+
return (
|
| 177 |
+
"Check the specific time, date, year, sequence, duration, or temporal "
|
| 178 |
+
"relationship asked for. Accept +/-1 only for day/week/month duration "
|
| 179 |
+
"counts, not for years, event identity, or ordering."
|
| 180 |
+
)
|
| 181 |
+
if task in ["Knowledge Update", "knowledge-update"]:
|
| 182 |
+
return (
|
| 183 |
+
"Check the current or most recent state. Historical context is fine "
|
| 184 |
+
"only if the final/current state is clearly correct. Outdated states "
|
| 185 |
+
"presented as current are incorrect."
|
| 186 |
+
)
|
| 187 |
+
if task == "single-session-preference":
|
| 188 |
+
return (
|
| 189 |
+
"Check whether the response recalls and applies the stated preference. "
|
| 190 |
+
"Do not require unnecessary verbosity, but contradictions are incorrect."
|
| 191 |
+
)
|
| 192 |
+
return "Check whether the model response satisfies all required facts in the reference answer."
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def get_anscheck_prompt(task, question, answer, response):
|
| 196 |
+
"""Legacy binary yes/no prompt kept for backward compatibility."""
|
| 197 |
+
if task in ["Information Absence"]:
|
| 198 |
+
template = """You are evaluating whether a model correctly identified that a question is unanswerable given the available personal chat history.
|
| 199 |
+
|
| 200 |
+
Question: {question}
|
| 201 |
+
|
| 202 |
+
Explanation of why it is unanswerable: {answer}
|
| 203 |
+
|
| 204 |
+
Model Response: {response}
|
| 205 |
+
|
| 206 |
+
Evaluation criteria:
|
| 207 |
+
- CORRECT if the model explicitly states that the information is not available, insufficient, or that the question cannot be answered from the provided context.
|
| 208 |
+
- INCORRECT if the model fabricates an answer or fails to acknowledge the unanswerable nature of the question.
|
| 209 |
+
- The model does not need to use the exact word "unanswerable" -- expressing uncertainty or lack of information is sufficient.
|
| 210 |
+
|
| 211 |
+
Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
|
| 212 |
+
elif task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
|
| 213 |
+
template = """You are evaluating whether a model correctly answered a question that requires aggregating specific facts from a user's personal chat history.
|
| 214 |
+
|
| 215 |
+
Question: {question}
|
| 216 |
+
|
| 217 |
+
Reference Answer: {answer}
|
| 218 |
+
|
| 219 |
+
Model Response: {response}
|
| 220 |
+
|
| 221 |
+
Evaluation criteria:
|
| 222 |
+
- CORRECT if the response identifies all key items, facts, or counts present in the reference answer, even if phrased differently or with added context.
|
| 223 |
+
- INCORRECT if the response:
|
| 224 |
+
- States a wrong count (e.g., says "5" when the answer is "3")
|
| 225 |
+
- Omits one or more key items/facts listed in the reference
|
| 226 |
+
- Lists mostly wrong items even if the count is right
|
| 227 |
+
- Partial answers that cover only a subset of required items are INCORRECT.
|
| 228 |
+
- Verbose responses are acceptable as long as all reference items are present within them.
|
| 229 |
+
- If the response contains correct items but also lists additional plausible-sounding but unverified items beyond the reference, this does NOT make it incorrect -- evaluate only whether the reference items are covered.
|
| 230 |
+
- Semantic equivalence counts as correct (e.g., "RSI" = "Repetitive Strain Injury").
|
| 231 |
+
|
| 232 |
+
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
|
| 233 |
+
elif task in ["Aggregation + Temporal"]:
|
| 234 |
+
template = """You are evaluating whether a model correctly answered a question that requires both aggregating facts AND reasoning about their temporal order or time associations from a user's personal chat history.
|
| 235 |
+
|
| 236 |
+
Question: {question}
|
| 237 |
+
|
| 238 |
+
Reference Answer: {answer}
|
| 239 |
+
|
| 240 |
+
Model Response: {response}
|
| 241 |
+
|
| 242 |
+
Evaluation criteria:
|
| 243 |
+
- CORRECT if the response captures both:
|
| 244 |
+
(a) all key events/facts listed in the reference, and
|
| 245 |
+
(b) their correct temporal associations (ordering, time periods, or "before/after" relationships).
|
| 246 |
+
- INCORRECT if the response:
|
| 247 |
+
- Omits one or more key events or facts from the reference
|
| 248 |
+
- Gets the temporal ordering or time associations wrong
|
| 249 |
+
- Captures only the content without the temporal aspects, or vice versa
|
| 250 |
+
- Responses that describe the correct progression/sequence in different words are acceptable.
|
| 251 |
+
- Partial answers covering only some events or ignoring time aspects are INCORRECT.
|
| 252 |
+
- Minor wording differences or additional explanatory context are acceptable.
|
| 253 |
+
|
| 254 |
+
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
|
| 255 |
+
elif task in ["Temporal Reasoning", "temporal-reasoning"]:
|
| 256 |
+
template = """You are evaluating whether a model correctly answered a question about temporal relationships in a user's personal chat history.
|
| 257 |
+
|
| 258 |
+
Question: {question}
|
| 259 |
+
|
| 260 |
+
Reference Answer: {answer}
|
| 261 |
+
|
| 262 |
+
Model Response: {response}
|
| 263 |
+
|
| 264 |
+
Evaluation criteria:
|
| 265 |
+
- CORRECT if the response correctly identifies the specific time, date, year, sequence, or temporal relationship asked about.
|
| 266 |
+
- INCORRECT if the response states a wrong year, wrong sequence, wrong temporal relationship, or misidentifies which event came first/last.
|
| 267 |
+
- Off-by-one tolerance: if the question asks for a count of days, weeks, or months, accept answers that differ by +/-1. This tolerance does NOT apply to years or to identifying specific events/artifacts.
|
| 268 |
+
- Responses that correctly identify the fact but with verbose context are acceptable.
|
| 269 |
+
- If the response hedges but still states the correct answer, it is correct.
|
| 270 |
+
|
| 271 |
+
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
|
| 272 |
+
elif task in ["Knowledge Update", "knowledge-update"]:
|
| 273 |
+
template = """You are evaluating whether a model correctly answered a question about the most recent or current state of something that changed over time in a user's personal chat history.
|
| 274 |
+
|
| 275 |
+
Question: {question}
|
| 276 |
+
|
| 277 |
+
Reference Answer: {answer}
|
| 278 |
+
|
| 279 |
+
Model Response: {response}
|
| 280 |
+
|
| 281 |
+
Evaluation criteria:
|
| 282 |
+
- CORRECT if the response correctly identifies the most recent/current state as described in the reference answer.
|
| 283 |
+
- The response may include earlier historical states as context -- this is acceptable as long as the current/final state is correctly identified and clearly stated.
|
| 284 |
+
- INCORRECT if the response:
|
| 285 |
+
- States an outdated/superseded state as the current one
|
| 286 |
+
- Omits the current state entirely
|
| 287 |
+
- Correctly describes history but draws the wrong conclusion about what the current state is
|
| 288 |
+
- Semantic equivalence counts (e.g., "flexitarian" and "semi-vegetarian diet with occasional meat" are equivalent if contextually clear).
|
| 289 |
+
|
| 290 |
+
Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
|
| 291 |
+
elif task == "single-session-preference":
|
| 292 |
+
template = """You are evaluating whether a model correctly answered a personalized question based on a user's stated preferences from their chat history.
|
| 293 |
+
|
| 294 |
+
Question: {question}
|
| 295 |
+
|
| 296 |
+
Reference Rubric: {answer}
|
| 297 |
+
|
| 298 |
+
Model Response: {response}
|
| 299 |
+
|
| 300 |
+
Evaluation criteria:
|
| 301 |
+
- CORRECT if the response recalls and applies the user's personal preferences correctly, even if not covering every point in the rubric.
|
| 302 |
+
- INCORRECT if the response ignores, contradicts, or misremembers the user's preferences.
|
| 303 |
+
|
| 304 |
+
Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
|
| 305 |
+
else:
|
| 306 |
+
template = """You are evaluating whether a model's response correctly answers a question based on a user's personal chat history.
|
| 307 |
+
|
| 308 |
+
Question: {question}
|
| 309 |
+
|
| 310 |
+
Reference Answer: {answer}
|
| 311 |
+
|
| 312 |
+
Model Response: {response}
|
| 313 |
+
|
| 314 |
+
Is the response correct? It is correct if it contains all key information from the reference answer, even if phrased differently.
|
| 315 |
+
|
| 316 |
+
Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
|
| 317 |
+
return template.format(question=question, answer=answer, response=response)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def build_rubric_prompt(task, question, answer):
|
| 321 |
+
return f"""You are creating an atomic grading rubric for an open-ended QA benchmark.
|
| 322 |
+
|
| 323 |
+
Question type: {task}
|
| 324 |
+
Question:
|
| 325 |
+
{question}
|
| 326 |
+
|
| 327 |
+
Reference answer:
|
| 328 |
+
{answer}
|
| 329 |
+
|
| 330 |
+
Question-type guidance:
|
| 331 |
+
{question_type_guidance(task)}
|
| 332 |
+
|
| 333 |
+
Decompose the reference answer into the smallest independently checkable requirements needed to answer the question.
|
| 334 |
+
|
| 335 |
+
Rules:
|
| 336 |
+
- Each atom should be a single required answer unit: an entity, count, date/year, order relation, current-state conclusion, or abstention requirement.
|
| 337 |
+
- If an entity and its temporal relation are inseparable for correctness, keep them in the same atom.
|
| 338 |
+
- For list/count questions, include an atom for the exact count when the question asks "how many", and atoms for each required listed item when the item identities matter.
|
| 339 |
+
- For Information Absence, usually use one atom requiring the response to clearly state that the information is unavailable/insufficient/not discussed, and add a strict note that concrete fabricated answers are wrong.
|
| 340 |
+
- Do not include supporting evidence requirements or session IDs unless the question explicitly asks for them.
|
| 341 |
+
- Weights should normally be 1.0. Use a higher weight only when one atom is clearly the main answer and other atoms are minor.
|
| 342 |
+
|
| 343 |
+
Return JSON only with this schema:
|
| 344 |
+
{{
|
| 345 |
+
"required_atoms": [
|
| 346 |
+
{{
|
| 347 |
+
"id": "a1",
|
| 348 |
+
"requirement": "short, specific grading requirement",
|
| 349 |
+
"weight": 1.0
|
| 350 |
+
}}
|
| 351 |
+
],
|
| 352 |
+
"strict_notes": ["short note about exactness, ordering, abstention, or hallucination handling"]
|
| 353 |
+
}}"""
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def build_atomic_eval_prompt(task, question, answer, response, rubric):
|
| 357 |
+
rubric_str = json.dumps(
|
| 358 |
+
{
|
| 359 |
+
"required_atoms": rubric["required_atoms"],
|
| 360 |
+
"strict_notes": rubric.get("strict_notes", []),
|
| 361 |
+
},
|
| 362 |
+
ensure_ascii=False,
|
| 363 |
+
indent=2,
|
| 364 |
+
)
|
| 365 |
+
return f"""You are an LLM-as-a-judge evaluating one model response against a reference answer.
|
| 366 |
+
|
| 367 |
+
Question type: {task}
|
| 368 |
+
Question:
|
| 369 |
+
{question}
|
| 370 |
+
|
| 371 |
+
Reference answer:
|
| 372 |
+
{answer}
|
| 373 |
+
|
| 374 |
+
Model response:
|
| 375 |
+
{response}
|
| 376 |
+
|
| 377 |
+
Atomic grading rubric:
|
| 378 |
+
{rubric_str}
|
| 379 |
+
|
| 380 |
+
Question-type guidance:
|
| 381 |
+
{question_type_guidance(task)}
|
| 382 |
+
|
| 383 |
+
Judge each atom independently.
|
| 384 |
+
|
| 385 |
+
Atom labels:
|
| 386 |
+
- correct: the response fully satisfies this atom, allowing semantic paraphrase.
|
| 387 |
+
- partially_correct: the response gets the main idea but is incomplete or slightly underspecified. Use this sparingly. Do not use it for wrong counts, wrong years, wrong named entities, wrong ordering, or a concrete answer to an Information Absence question.
|
| 388 |
+
- missing: the response does not address this atom.
|
| 389 |
+
- incorrect: the response contradicts this atom or gives the wrong count/entity/date/order/current state.
|
| 390 |
+
|
| 391 |
+
Also identify unsupported_or_contradictory material:
|
| 392 |
+
- severity "material": extra answer content that changes the final answer, adds extra items to an exact list/count, gives an outdated current state, fabricates a concrete answer for Information Absence, or contradicts any atom.
|
| 393 |
+
- severity "minor": harmless context or extra explanation that does not change the answer.
|
| 394 |
+
|
| 395 |
+
Return JSON only with this schema:
|
| 396 |
+
{{
|
| 397 |
+
"atom_judgments": [
|
| 398 |
+
{{
|
| 399 |
+
"id": "a1",
|
| 400 |
+
"label": "correct|partially_correct|missing|incorrect",
|
| 401 |
+
"rationale": "brief reason"
|
| 402 |
+
}}
|
| 403 |
+
],
|
| 404 |
+
"unsupported_or_contradictory": [
|
| 405 |
+
{{
|
| 406 |
+
"text": "extra or contradictory claim",
|
| 407 |
+
"severity": "minor|material",
|
| 408 |
+
"rationale": "brief reason"
|
| 409 |
+
}}
|
| 410 |
+
],
|
| 411 |
+
"absence_mismatch": false,
|
| 412 |
+
"overall_rationale": "one or two sentence summary"
|
| 413 |
+
}}"""
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def llm_call(
|
| 417 |
+
deployment_name: str,
|
| 418 |
+
api_version: str,
|
| 419 |
+
_prompt: str,
|
| 420 |
+
debug: bool = False,
|
| 421 |
+
vllm: bool = False,
|
| 422 |
+
tritonai: bool = False,
|
| 423 |
+
nvidia: bool = False,
|
| 424 |
+
):
|
| 425 |
+
if nvidia:
|
| 426 |
+
client = OpenAI(
|
| 427 |
+
api_key=os.getenv("NV_API_KEY"),
|
| 428 |
+
base_url="https://inference-api.nvidia.com",
|
| 429 |
+
)
|
| 430 |
+
while True:
|
| 431 |
+
try:
|
| 432 |
+
return client.chat.completions.create(
|
| 433 |
+
model=deployment_name,
|
| 434 |
+
messages=[{"role": "user", "content": _prompt}],
|
| 435 |
+
)
|
| 436 |
+
except Exception as e:
|
| 437 |
+
st = _retryable_status(e)
|
| 438 |
+
if st in (429, 500, 503, 403):
|
| 439 |
+
print(f"[WARN] HTTP {st} from NVIDIA API; sleeping 60s then retrying...", flush=True)
|
| 440 |
+
time.sleep(60)
|
| 441 |
+
continue
|
| 442 |
+
print("One exception captured", repr(e), flush=True)
|
| 443 |
+
raise
|
| 444 |
+
|
| 445 |
+
if tritonai:
|
| 446 |
+
client = OpenAI(
|
| 447 |
+
api_key=os.getenv("TRITONAI_API_KEY"),
|
| 448 |
+
base_url=TRITONAI_BASE_URL,
|
| 449 |
+
)
|
| 450 |
+
while True:
|
| 451 |
+
try:
|
| 452 |
+
return client.chat.completions.create(
|
| 453 |
+
model=deployment_name,
|
| 454 |
+
messages=[{"role": "user", "content": _prompt}],
|
| 455 |
+
)
|
| 456 |
+
except Exception as e:
|
| 457 |
+
st = _retryable_status(e)
|
| 458 |
+
if st in (429, 500, 503, 403):
|
| 459 |
+
print(f"[WARN] HTTP {st} from LiteLLM proxy; sleeping 60s then retrying...", flush=True)
|
| 460 |
+
time.sleep(60)
|
| 461 |
+
continue
|
| 462 |
+
print("One exception captured", repr(e), flush=True)
|
| 463 |
+
raise
|
| 464 |
+
|
| 465 |
+
if deployment_name.startswith("claude-"):
|
| 466 |
+
import anthropic
|
| 467 |
+
|
| 468 |
+
client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
| 469 |
+
while True:
|
| 470 |
+
try:
|
| 471 |
+
msg = client.messages.create(
|
| 472 |
+
model=deployment_name,
|
| 473 |
+
max_tokens=1024,
|
| 474 |
+
messages=[{"role": "user", "content": _prompt}],
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
class _Choice:
|
| 478 |
+
class _Msg:
|
| 479 |
+
def __init__(self, text):
|
| 480 |
+
self.content = text
|
| 481 |
+
|
| 482 |
+
def __init__(self, text):
|
| 483 |
+
self.message = self._Msg(text)
|
| 484 |
+
|
| 485 |
+
class _Completion:
|
| 486 |
+
def __init__(self, text):
|
| 487 |
+
self.choices = [_Choice(text)]
|
| 488 |
+
|
| 489 |
+
return _Completion(msg.content[0].text)
|
| 490 |
+
except Exception as e:
|
| 491 |
+
st = _retryable_status(e)
|
| 492 |
+
if st in (429, 500, 503, 403):
|
| 493 |
+
print(f"[WARN] HTTP {st} from Anthropic; sleeping 60s then retrying...", flush=True)
|
| 494 |
+
time.sleep(60)
|
| 495 |
+
continue
|
| 496 |
+
print("One exception captured", repr(e), flush=True)
|
| 497 |
+
raise
|
| 498 |
+
|
| 499 |
+
if vllm:
|
| 500 |
+
client = OpenAI(
|
| 501 |
+
base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
|
| 502 |
+
api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
|
| 503 |
+
)
|
| 504 |
+
elif debug:
|
| 505 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 506 |
+
else:
|
| 507 |
+
client = AzureOpenAI(
|
| 508 |
+
azure_endpoint=endpoint,
|
| 509 |
+
azure_ad_token_provider=credential,
|
| 510 |
+
api_version=api_version,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
kwargs = {
|
| 514 |
+
"model": deployment_name,
|
| 515 |
+
"messages": [{"role": "system", "content": _prompt}],
|
| 516 |
+
}
|
| 517 |
+
while True:
|
| 518 |
+
try:
|
| 519 |
+
return client.chat.completions.create(**kwargs)
|
| 520 |
+
except Exception as e:
|
| 521 |
+
st = _retryable_status(e)
|
| 522 |
+
if st in (429, 500, 503, 403):
|
| 523 |
+
print(f"[WARN] HTTP {st} from LLM; sleeping 120s then retrying...", flush=True)
|
| 524 |
+
time.sleep(120)
|
| 525 |
+
continue
|
| 526 |
+
print("One exception captured", repr(e), flush=True)
|
| 527 |
+
raise
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def call_json_llm(prompt, deployment_name, api_version, args, max_retries=3):
|
| 531 |
+
last_error = None
|
| 532 |
+
for attempt in range(max_retries):
|
| 533 |
+
completion = llm_call(
|
| 534 |
+
deployment_name,
|
| 535 |
+
api_version,
|
| 536 |
+
prompt,
|
| 537 |
+
debug=args.debug,
|
| 538 |
+
vllm=args.vllm,
|
| 539 |
+
tritonai=args.tritonai,
|
| 540 |
+
nvidia=args.nvidia,
|
| 541 |
+
)
|
| 542 |
+
content = completion.choices[0].message.content.strip()
|
| 543 |
+
try:
|
| 544 |
+
return parse_json_object(content), content
|
| 545 |
+
except Exception as e:
|
| 546 |
+
last_error = e
|
| 547 |
+
if attempt < max_retries - 1:
|
| 548 |
+
print(f"[WARN] Failed to parse judge JSON; retrying ({attempt + 1}/{max_retries})", flush=True)
|
| 549 |
+
time.sleep(2)
|
| 550 |
+
raise ValueError(f"Failed to parse JSON response from judge: {last_error}")
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def fallback_rubric(qid, task, question, answer):
|
| 554 |
+
return {
|
| 555 |
+
"question_id": qid,
|
| 556 |
+
"question_type": task,
|
| 557 |
+
"question": question,
|
| 558 |
+
"reference_answer": answer,
|
| 559 |
+
"required_atoms": [
|
| 560 |
+
{
|
| 561 |
+
"id": "a1",
|
| 562 |
+
"requirement": f"Response must correctly answer the question according to the reference answer: {answer}",
|
| 563 |
+
"weight": 1.0,
|
| 564 |
+
}
|
| 565 |
+
],
|
| 566 |
+
"strict_notes": ["Fallback single-atom rubric produced because the generated rubric was invalid."],
|
| 567 |
+
"prompt_version": ATOMIC_PROMPT_VERSION,
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def normalize_rubric(qid, task, question, answer, parsed):
|
| 572 |
+
atoms = parsed.get("required_atoms", []) if isinstance(parsed, dict) else []
|
| 573 |
+
norm_atoms = []
|
| 574 |
+
for idx, atom in enumerate(atoms, start=1):
|
| 575 |
+
if not isinstance(atom, dict):
|
| 576 |
+
continue
|
| 577 |
+
requirement = str(atom.get("requirement", "")).strip()
|
| 578 |
+
if not requirement:
|
| 579 |
+
continue
|
| 580 |
+
atom_id = str(atom.get("id", f"a{idx}")).strip() or f"a{idx}"
|
| 581 |
+
try:
|
| 582 |
+
weight = float(atom.get("weight", 1.0))
|
| 583 |
+
except (TypeError, ValueError):
|
| 584 |
+
weight = 1.0
|
| 585 |
+
if weight <= 0:
|
| 586 |
+
weight = 1.0
|
| 587 |
+
norm_atoms.append({"id": atom_id, "requirement": requirement, "weight": weight})
|
| 588 |
+
if not norm_atoms:
|
| 589 |
+
return fallback_rubric(qid, task, question, answer)
|
| 590 |
+
strict_notes = parsed.get("strict_notes", [])
|
| 591 |
+
if not isinstance(strict_notes, list):
|
| 592 |
+
strict_notes = [str(strict_notes)]
|
| 593 |
+
return {
|
| 594 |
+
"question_id": qid,
|
| 595 |
+
"question_type": task,
|
| 596 |
+
"question": question,
|
| 597 |
+
"reference_answer": answer,
|
| 598 |
+
"required_atoms": norm_atoms,
|
| 599 |
+
"strict_notes": [str(x) for x in strict_notes],
|
| 600 |
+
"prompt_version": ATOMIC_PROMPT_VERSION,
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args):
|
| 605 |
+
qid = qdata["question_id"]
|
| 606 |
+
existing = rubric_data["rubrics"].get(qid)
|
| 607 |
+
if (
|
| 608 |
+
existing
|
| 609 |
+
and existing.get("prompt_version") == ATOMIC_PROMPT_VERSION
|
| 610 |
+
and existing.get("required_atoms")
|
| 611 |
+
and not args.force_rebuild_rubric
|
| 612 |
+
):
|
| 613 |
+
return existing
|
| 614 |
+
|
| 615 |
+
task = qdata["question_type"]
|
| 616 |
+
prompt = build_rubric_prompt(task, qdata["question"], qdata["answer"])
|
| 617 |
+
try:
|
| 618 |
+
parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
|
| 619 |
+
rubric = normalize_rubric(qid, task, qdata["question"], qdata["answer"], parsed)
|
| 620 |
+
rubric["rubric_raw_response"] = raw
|
| 621 |
+
except Exception as e:
|
| 622 |
+
print(f"[WARN] Falling back to single-atom rubric for {qid}: {e}", flush=True)
|
| 623 |
+
rubric = fallback_rubric(qid, task, qdata["question"], qdata["answer"])
|
| 624 |
+
rubric_data["rubrics"][qid] = rubric
|
| 625 |
+
write_json(rubric_file, rubric_data)
|
| 626 |
+
return rubric
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def compute_atomic_scores(rubric, parsed):
|
| 630 |
+
atoms = rubric.get("required_atoms", [])
|
| 631 |
+
judgments_by_id = {}
|
| 632 |
+
raw_judgments = parsed.get("atom_judgments", []) if isinstance(parsed, dict) else []
|
| 633 |
+
if isinstance(raw_judgments, list):
|
| 634 |
+
for judgment in raw_judgments:
|
| 635 |
+
if not isinstance(judgment, dict):
|
| 636 |
+
continue
|
| 637 |
+
atom_id = str(judgment.get("id", "")).strip()
|
| 638 |
+
label = str(judgment.get("label", "")).strip()
|
| 639 |
+
if label not in ATOM_SCORES:
|
| 640 |
+
label = "incorrect"
|
| 641 |
+
judgments_by_id[atom_id] = {
|
| 642 |
+
"id": atom_id,
|
| 643 |
+
"label": label,
|
| 644 |
+
"score": ATOM_SCORES[label],
|
| 645 |
+
"rationale": str(judgment.get("rationale", "")),
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
norm_judgments = []
|
| 649 |
+
weighted_score = 0.0
|
| 650 |
+
total_weight = 0.0
|
| 651 |
+
for atom in atoms:
|
| 652 |
+
atom_id = atom["id"]
|
| 653 |
+
weight = float(atom.get("weight", 1.0))
|
| 654 |
+
judgment = judgments_by_id.get(
|
| 655 |
+
atom_id,
|
| 656 |
+
{"id": atom_id, "label": "missing", "score": 0.0, "rationale": "No judgment returned."},
|
| 657 |
+
)
|
| 658 |
+
judgment["requirement"] = atom["requirement"]
|
| 659 |
+
judgment["weight"] = weight
|
| 660 |
+
norm_judgments.append(judgment)
|
| 661 |
+
weighted_score += judgment["score"] * weight
|
| 662 |
+
total_weight += weight
|
| 663 |
+
|
| 664 |
+
extras = parsed.get("unsupported_or_contradictory", []) if isinstance(parsed, dict) else []
|
| 665 |
+
if not isinstance(extras, list):
|
| 666 |
+
extras = []
|
| 667 |
+
material_extras = [
|
| 668 |
+
x for x in extras
|
| 669 |
+
if isinstance(x, dict) and str(x.get("severity", "")).strip() == "material"
|
| 670 |
+
]
|
| 671 |
+
absence_mismatch = bool(parsed.get("absence_mismatch", False)) if isinstance(parsed, dict) else False
|
| 672 |
+
|
| 673 |
+
strict_label = (
|
| 674 |
+
bool(norm_judgments)
|
| 675 |
+
and all(j["label"] == "correct" for j in norm_judgments)
|
| 676 |
+
and not material_extras
|
| 677 |
+
and not absence_mismatch
|
| 678 |
+
)
|
| 679 |
+
partial_score = weighted_score / total_weight if total_weight > 0 else 0.0
|
| 680 |
+
if absence_mismatch:
|
| 681 |
+
partial_score = 0.0
|
| 682 |
+
elif material_extras and partial_score > 0.8:
|
| 683 |
+
partial_score = 0.8
|
| 684 |
+
|
| 685 |
+
return {
|
| 686 |
+
"strict_label": strict_label,
|
| 687 |
+
"partial_score": round(partial_score, 4),
|
| 688 |
+
"atom_judgments": norm_judgments,
|
| 689 |
+
"unsupported_or_contradictory": extras,
|
| 690 |
+
"absence_mismatch": absence_mismatch,
|
| 691 |
+
"overall_rationale": str(parsed.get("overall_rationale", "")) if isinstance(parsed, dict) else "",
|
| 692 |
+
}
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def judge_atomic(qdata, hypothesis, rubric, deployment_name, api_version, args):
|
| 696 |
+
prompt = build_atomic_eval_prompt(
|
| 697 |
+
qdata["question_type"],
|
| 698 |
+
qdata["question"],
|
| 699 |
+
qdata["answer"],
|
| 700 |
+
hypothesis,
|
| 701 |
+
rubric,
|
| 702 |
+
)
|
| 703 |
+
parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
|
| 704 |
+
scores = compute_atomic_scores(rubric, parsed)
|
| 705 |
+
return {
|
| 706 |
+
"model": args.eval_model_name,
|
| 707 |
+
"prompt_version": ATOMIC_PROMPT_VERSION,
|
| 708 |
+
"eval_mode": args.eval_mode,
|
| 709 |
+
"strict_label": scores["strict_label"],
|
| 710 |
+
"partial_score": scores["partial_score"],
|
| 711 |
+
"required_atoms": rubric["required_atoms"],
|
| 712 |
+
"atom_judgments": scores["atom_judgments"],
|
| 713 |
+
"unsupported_or_contradictory": scores["unsupported_or_contradictory"],
|
| 714 |
+
"absence_mismatch": scores["absence_mismatch"],
|
| 715 |
+
"overall_rationale": scores["overall_rationale"],
|
| 716 |
+
"raw_response": raw,
|
| 717 |
+
}
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
def should_skip_existing(existing_row, eval_mode):
|
| 721 |
+
if eval_mode == "legacy":
|
| 722 |
+
return "autoeval_label" in existing_row
|
| 723 |
+
atomic = existing_row.get("autoeval_atomic")
|
| 724 |
+
return bool(atomic and atomic.get("prompt_version") == ATOMIC_PROMPT_VERSION)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def safe_mean(values):
|
| 728 |
+
if not values:
|
| 729 |
+
return float("nan")
|
| 730 |
+
return float(np.mean(values))
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
def print_legacy_summary(logs, qtype2acc):
|
| 734 |
+
labels = [1 if x["autoeval_label"]["label"] else 0 for x in logs if "autoeval_label" in x]
|
| 735 |
+
print("Accuracy:", round(safe_mean(labels), 4))
|
| 736 |
+
for k, v in sorted(qtype2acc.items()):
|
| 737 |
+
print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def print_atomic_summary(logs, qtype2strict, qtype2partial, eval_mode):
|
| 741 |
+
strict_values = [
|
| 742 |
+
1 if x["autoeval_atomic"]["strict_label"] else 0
|
| 743 |
+
for x in logs
|
| 744 |
+
if "autoeval_atomic" in x
|
| 745 |
+
]
|
| 746 |
+
partial_values = [
|
| 747 |
+
float(x["autoeval_atomic"]["partial_score"])
|
| 748 |
+
for x in logs
|
| 749 |
+
if "autoeval_atomic" in x
|
| 750 |
+
]
|
| 751 |
+
if eval_mode in ("strict", "both"):
|
| 752 |
+
print("Strict Accuracy:", round(safe_mean(strict_values), 4))
|
| 753 |
+
for k, v in sorted(qtype2strict.items()):
|
| 754 |
+
print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
|
| 755 |
+
if eval_mode in ("partial", "both"):
|
| 756 |
+
print("Partial Score:", round(safe_mean(partial_values), 4))
|
| 757 |
+
for k, v in sorted(qtype2partial.items()):
|
| 758 |
+
print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def main():
|
| 762 |
+
parser = argparse.ArgumentParser()
|
| 763 |
+
parser.add_argument("--hyp_file", type=str, default=None)
|
| 764 |
+
parser.add_argument("--ref_file", type=str, required=True)
|
| 765 |
+
parser.add_argument("--eval_model_name", type=str, required=True)
|
| 766 |
+
parser.add_argument(
|
| 767 |
+
"--eval_mode",
|
| 768 |
+
type=str,
|
| 769 |
+
default="both",
|
| 770 |
+
choices=["legacy", "strict", "partial", "both"],
|
| 771 |
+
help="legacy uses the old yes/no judge; strict/partial/both use atomic JSON judging.",
|
| 772 |
+
)
|
| 773 |
+
parser.add_argument("--rubric_file", type=str, default=None)
|
| 774 |
+
parser.add_argument("--build_rubric_only", action="store_true", default=False)
|
| 775 |
+
parser.add_argument("--force_rebuild_rubric", action="store_true", default=False)
|
| 776 |
+
parser.add_argument("--result_file", type=str, default=None)
|
| 777 |
+
parser.add_argument("--debug", action="store_true", default=False)
|
| 778 |
+
parser.add_argument("--vllm", action="store_true", default=False)
|
| 779 |
+
parser.add_argument("--tritonai", action="store_true", default=False,
|
| 780 |
+
help="Use OpenAI-compatible LiteLLM proxy (set TRITONAI_API_KEY env var)")
|
| 781 |
+
parser.add_argument("--nvidia", action="store_true", default=False,
|
| 782 |
+
help="Use NVIDIA inference API (set NV_API_KEY env var)")
|
| 783 |
+
parser.add_argument("--verbose", action=argparse.BooleanOptionalAction, default=True)
|
| 784 |
+
args = parser.parse_args()
|
| 785 |
+
|
| 786 |
+
if not args.build_rubric_only and not args.hyp_file:
|
| 787 |
+
parser.error("--hyp_file is required unless --build_rubric_only is set")
|
| 788 |
+
|
| 789 |
+
metric_model = args.eval_model_name
|
| 790 |
+
deployment_name, api_version = model_zoo[metric_model]
|
| 791 |
+
references = read_json_or_jsonl(args.ref_file)
|
| 792 |
+
qid2qdata = {entry["question_id"]: entry for entry in references}
|
| 793 |
+
qid2qtype = {entry["question_id"]: entry["question_type"] for entry in references}
|
| 794 |
+
qtypes = set(qid2qtype.values())
|
| 795 |
+
|
| 796 |
+
rubric_data = None
|
| 797 |
+
rubric_file = args.rubric_file or default_rubric_file(args.ref_file)
|
| 798 |
+
if args.eval_mode != "legacy" or args.build_rubric_only:
|
| 799 |
+
rubric_data = load_rubric_file(rubric_file, args.ref_file)
|
| 800 |
+
|
| 801 |
+
if args.build_rubric_only:
|
| 802 |
+
for entry in tqdm(references, desc="building rubrics"):
|
| 803 |
+
get_or_build_rubric(entry, rubric_data, rubric_file, deployment_name, api_version, args)
|
| 804 |
+
print(f"Saved rubric file to {rubric_file}")
|
| 805 |
+
return
|
| 806 |
+
|
| 807 |
+
result_file = args.result_file or default_result_file(args.hyp_file, metric_model, args.eval_mode)
|
| 808 |
+
existing = read_existing_jsonl(result_file)
|
| 809 |
+
hypotheses = read_json_or_jsonl(args.hyp_file)
|
| 810 |
+
|
| 811 |
+
qtype2acc = {t: [] for t in qtypes}
|
| 812 |
+
qtype2strict = {t: [] for t in qtypes}
|
| 813 |
+
qtype2partial = {t: [] for t in qtypes}
|
| 814 |
+
logs = []
|
| 815 |
+
|
| 816 |
+
for entry in tqdm(hypotheses):
|
| 817 |
+
qid = entry.get("question_id")
|
| 818 |
+
if qid not in qid2qtype:
|
| 819 |
+
if qid is not None:
|
| 820 |
+
print(f"Warning: skipping {qid} as it is not in reference data.")
|
| 821 |
+
continue
|
| 822 |
+
|
| 823 |
+
if qid in existing and should_skip_existing(existing[qid], args.eval_mode):
|
| 824 |
+
existing_row = existing[qid]
|
| 825 |
+
logs.append(existing_row)
|
| 826 |
+
qtype = qid2qtype[qid]
|
| 827 |
+
if args.eval_mode == "legacy":
|
| 828 |
+
label = existing_row["autoeval_label"]["label"]
|
| 829 |
+
qtype2acc[qtype].append(1 if label else 0)
|
| 830 |
+
else:
|
| 831 |
+
atomic = existing_row["autoeval_atomic"]
|
| 832 |
+
qtype2strict[qtype].append(1 if atomic["strict_label"] else 0)
|
| 833 |
+
qtype2partial[qtype].append(float(atomic["partial_score"]))
|
| 834 |
+
continue
|
| 835 |
+
|
| 836 |
+
qdata = qid2qdata[qid]
|
| 837 |
+
qtype = qdata["question_type"]
|
| 838 |
+
hyp = entry["hypothesis"]
|
| 839 |
+
|
| 840 |
+
if args.eval_mode == "legacy":
|
| 841 |
+
prompt = get_anscheck_prompt(qtype, qdata["question"], qdata["answer"], hyp)
|
| 842 |
+
completion = llm_call(
|
| 843 |
+
deployment_name,
|
| 844 |
+
api_version,
|
| 845 |
+
prompt,
|
| 846 |
+
debug=args.debug,
|
| 847 |
+
vllm=args.vllm,
|
| 848 |
+
tritonai=args.tritonai,
|
| 849 |
+
nvidia=args.nvidia,
|
| 850 |
+
)
|
| 851 |
+
eval_response = completion.choices[0].message.content.strip()
|
| 852 |
+
last_line = next((l.strip().lower() for l in reversed(eval_response.splitlines()) if l.strip()), "")
|
| 853 |
+
label = last_line == "yes" or last_line.startswith("yes")
|
| 854 |
+
row = dict(entry)
|
| 855 |
+
row["autoeval_label"] = {
|
| 856 |
+
"model": metric_model,
|
| 857 |
+
"prompt_version": LEGACY_PROMPT_VERSION,
|
| 858 |
+
"label": label,
|
| 859 |
+
"raw_response": eval_response,
|
| 860 |
+
}
|
| 861 |
+
logs.append(row)
|
| 862 |
+
qtype2acc[qtype].append(1 if label else 0)
|
| 863 |
+
if args.verbose:
|
| 864 |
+
print(json.dumps({
|
| 865 |
+
"question": qdata["question"],
|
| 866 |
+
"answer": qdata["answer"],
|
| 867 |
+
"hypothesis": hyp,
|
| 868 |
+
"autoeval_label": label,
|
| 869 |
+
}, indent=4), flush=True)
|
| 870 |
+
append_jsonl(result_file, row)
|
| 871 |
+
continue
|
| 872 |
+
|
| 873 |
+
rubric = get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args)
|
| 874 |
+
try:
|
| 875 |
+
atomic_eval = judge_atomic(qdata, hyp, rubric, deployment_name, api_version, args)
|
| 876 |
+
except ValueError as _judge_err:
|
| 877 |
+
print(f"[WARN] judge_atomic failed for {qdata['question_id']}, writing zero score: {_judge_err}", flush=True)
|
| 878 |
+
atoms = rubric.get("required_atoms", [])
|
| 879 |
+
atomic_eval = {
|
| 880 |
+
"model": deployment_name,
|
| 881 |
+
"prompt_version": ATOMIC_PROMPT_VERSION,
|
| 882 |
+
"eval_mode": args.eval_mode,
|
| 883 |
+
"strict_label": False,
|
| 884 |
+
"partial_score": 0.0,
|
| 885 |
+
"required_atoms": atoms,
|
| 886 |
+
"atom_judgments": [{"id": a["id"], "label": "error", "score": 0.0, "rationale": "judge parse error", "requirement": a.get("requirement", ""), "weight": a.get("weight", 1.0)} for a in atoms],
|
| 887 |
+
"unsupported_or_contradictory": [],
|
| 888 |
+
"absence_mismatch": False,
|
| 889 |
+
"overall_rationale": f"Skipped: judge JSON parse error ({_judge_err})",
|
| 890 |
+
}
|
| 891 |
+
row = dict(entry)
|
| 892 |
+
row["autoeval_atomic"] = atomic_eval
|
| 893 |
+
logs.append(row)
|
| 894 |
+
qtype2strict[qtype].append(1 if atomic_eval["strict_label"] else 0)
|
| 895 |
+
qtype2partial[qtype].append(float(atomic_eval["partial_score"]))
|
| 896 |
+
if args.verbose:
|
| 897 |
+
print(json.dumps({
|
| 898 |
+
"question": qdata["question"],
|
| 899 |
+
"answer": qdata["answer"],
|
| 900 |
+
"hypothesis": hyp,
|
| 901 |
+
"strict_label": atomic_eval["strict_label"],
|
| 902 |
+
"partial_score": atomic_eval["partial_score"],
|
| 903 |
+
"atom_judgments": atomic_eval["atom_judgments"],
|
| 904 |
+
}, indent=4), flush=True)
|
| 905 |
+
append_jsonl(result_file, row)
|
| 906 |
+
|
| 907 |
+
if args.eval_mode == "legacy":
|
| 908 |
+
print_legacy_summary(logs, qtype2acc)
|
| 909 |
+
else:
|
| 910 |
+
print_atomic_summary(logs, qtype2strict, qtype2partial, args.eval_mode)
|
| 911 |
+
print(f"Rubric file: {rubric_file}")
|
| 912 |
+
print("Saved to", result_file)
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
if __name__ == "__main__":
|
| 916 |
+
main()
|
main.py
ADDED
|
@@ -0,0 +1,1717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
from json import JSONDecodeError
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
|
| 10 |
+
from openai import OpenAI
|
| 11 |
+
try:
|
| 12 |
+
from openai import AzureOpenAI
|
| 13 |
+
from azure.identity import ChainedTokenCredential, AzureCliCredential, ManagedIdentityCredential, get_bearer_token_provider
|
| 14 |
+
# Azure scope for the OAuth bearer-token provider; override per deployment.
|
| 15 |
+
AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
|
| 16 |
+
if AZURE_OAUTH_SCOPE:
|
| 17 |
+
credential = get_bearer_token_provider(ChainedTokenCredential(
|
| 18 |
+
AzureCliCredential(),
|
| 19 |
+
ManagedIdentityCredential(),
|
| 20 |
+
), AZURE_OAUTH_SCOPE)
|
| 21 |
+
else:
|
| 22 |
+
credential = None
|
| 23 |
+
except ImportError:
|
| 24 |
+
AzureOpenAI = None
|
| 25 |
+
credential = None
|
| 26 |
+
|
| 27 |
+
from model_zoo import model_zoo
|
| 28 |
+
from memory import EpisodicMemoryStore, SemanticMemoryStore
|
| 29 |
+
try:
|
| 30 |
+
import tiktoken
|
| 31 |
+
except ImportError:
|
| 32 |
+
tiktoken = None
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
| 36 |
+
except ImportError:
|
| 37 |
+
AutoTokenizer = None
|
| 38 |
+
PreTrainedTokenizerBase = ()
|
| 39 |
+
from collections import defaultdict
|
| 40 |
+
|
| 41 |
+
def get_hf_tokenizer_for_vllm(model_name: str):
|
| 42 |
+
return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL).
|
| 46 |
+
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
|
| 47 |
+
|
| 48 |
+
# OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy).
|
| 49 |
+
TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "")
|
| 50 |
+
|
| 51 |
+
# reading cached files
|
| 52 |
+
qid2plan = {}
|
| 53 |
+
if 'plan_cache' in os.environ:
|
| 54 |
+
plan_cache_file = os.environ['plan_cache']
|
| 55 |
+
if os.path.exists(plan_cache_file):
|
| 56 |
+
qid2plan = json.load(open(plan_cache_file))
|
| 57 |
+
else:
|
| 58 |
+
plan_cache_file = 'response_cache/qa/evolv_mem_v3_plan_cache_gpt5-1'
|
| 59 |
+
|
| 60 |
+
veri_reading_log_file = os.environ['reading_cache']
|
| 61 |
+
qid2rel_sess_ids = {}
|
| 62 |
+
if os.path.exists(veri_reading_log_file):
|
| 63 |
+
qid2rel_sess_ids = json.load(open(veri_reading_log_file))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Cache file for retrieval results to avoid re-running expensive retrieval operations.
|
| 67 |
+
# Stores pre-computed search results for questions, including:
|
| 68 |
+
# - Question metadata (id, type, text, answer, dates)
|
| 69 |
+
# - Haystack information (session dates, content, IDs)
|
| 70 |
+
# - Retrieved results with query, ranked items, and evaluation metrics
|
| 71 |
+
# Format: JSONL file where each line contains a complete retrieval result for one question
|
| 72 |
+
retrieved_log_file = None
|
| 73 |
+
if 'ret_cache' in os.environ:
|
| 74 |
+
retrieved_log_file = os.environ['ret_cache']
|
| 75 |
+
|
| 76 |
+
print("loading existing retrieved results ...")
|
| 77 |
+
retrieved_data = [json.loads(line) for line in open(retrieved_log_file).readlines()]
|
| 78 |
+
retrieved_data_dict = {x['question_id']: x for x in retrieved_data}
|
| 79 |
+
valid_sess_set = set(json.load(open("dataset/all_sessions.json")).keys())
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def parse_json(response_content):
|
| 83 |
+
"""Safely parse JSON content from a string response."""
|
| 84 |
+
candidates = []
|
| 85 |
+
|
| 86 |
+
if '```json' in response_content:
|
| 87 |
+
start_idx = response_content.find('```json') + 7
|
| 88 |
+
end_idx = response_content.rfind('```')
|
| 89 |
+
if end_idx > start_idx:
|
| 90 |
+
candidates.append(response_content[start_idx:end_idx].strip())
|
| 91 |
+
|
| 92 |
+
# Also try brace-based extraction as fallback
|
| 93 |
+
brace_start = response_content.find('{')
|
| 94 |
+
brace_end = response_content.rfind('}') + 1
|
| 95 |
+
if brace_start >= 0 and brace_end > brace_start:
|
| 96 |
+
candidates.append(response_content[brace_start:brace_end].strip())
|
| 97 |
+
|
| 98 |
+
for json_block in candidates:
|
| 99 |
+
try:
|
| 100 |
+
result = json.loads(json_block)
|
| 101 |
+
return result
|
| 102 |
+
except (JSONDecodeError, ValueError):
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
print(f"[Warning] Failed to decode JSON from response (all strategies failed)")
|
| 106 |
+
print(f"[Debug] Raw response content (truncated): {response_content[:500]}")
|
| 107 |
+
return {}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _retryable_status(e):
|
| 111 |
+
# Try common attributes first — return any extractable HTTP status code
|
| 112 |
+
status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
|
| 113 |
+
if status is not None:
|
| 114 |
+
return int(status)
|
| 115 |
+
resp = getattr(e, "response", None)
|
| 116 |
+
if resp is not None and getattr(resp, "status_code", None) is not None:
|
| 117 |
+
return int(resp.status_code)
|
| 118 |
+
# Fallback: infer from message text
|
| 119 |
+
msg = str(e).lower()
|
| 120 |
+
if "429" in msg or "rate limit" in msg:
|
| 121 |
+
return 429
|
| 122 |
+
if "500" in msg or "internal server error" in msg:
|
| 123 |
+
return 500
|
| 124 |
+
if "503" in msg or "API Configuration unavailable" in msg:
|
| 125 |
+
return 503
|
| 126 |
+
return None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
MAX_CONTEXT_TOKENS = 272_000
|
| 130 |
+
#MAX_CONTEXT_TOKENS = 256_000
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class CharacterEncoder:
|
| 134 |
+
"""Conservative fallback when tokenizer packages are unavailable."""
|
| 135 |
+
|
| 136 |
+
def encode(self, text, **kwargs):
|
| 137 |
+
return list(text)
|
| 138 |
+
|
| 139 |
+
def decode(self, toks):
|
| 140 |
+
return "".join(toks)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _get_encoder(model_name: str):
|
| 144 |
+
"""
|
| 145 |
+
Return a token encoder for the given model name.
|
| 146 |
+
Cached to avoid reloading HF tokenizers on every call.
|
| 147 |
+
"""
|
| 148 |
+
# Prefer explicit handling for Qwen models first
|
| 149 |
+
# Note: Adjust the path below if you have the model downloaded locally
|
| 150 |
+
if AutoTokenizer is not None and any(k in model_name for k in ["Qwen3", "Qwen/", "Qwen"]):
|
| 151 |
+
try:
|
| 152 |
+
return AutoTokenizer.from_pretrained(
|
| 153 |
+
"Qwen/Qwen3-30B-A3B-Instruct-2507",
|
| 154 |
+
trust_remote_code=True,
|
| 155 |
+
use_fast=False
|
| 156 |
+
)
|
| 157 |
+
except Exception as e:
|
| 158 |
+
print(f"[WARN] Failed to load Qwen tokenizer: {e}. Falling back to tiktoken.")
|
| 159 |
+
|
| 160 |
+
# For non-Qwen models, rely on tiktoken's mapping when possible
|
| 161 |
+
if tiktoken is not None:
|
| 162 |
+
try:
|
| 163 |
+
return tiktoken.encoding_for_model(model_name)
|
| 164 |
+
except Exception:
|
| 165 |
+
# Generic safe fallback
|
| 166 |
+
return tiktoken.get_encoding("cl100k_base")
|
| 167 |
+
return CharacterEncoder()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _truncate_to_tokens(text, enc, max_tokens) -> str:
|
| 171 |
+
"""
|
| 172 |
+
Truncates text to the last `max_tokens`.
|
| 173 |
+
Compatible with both tiktoken and Hugging Face AutoTokenizers.
|
| 174 |
+
"""
|
| 175 |
+
# Handle Hugging Face Tokenizers
|
| 176 |
+
if PreTrainedTokenizerBase and isinstance(enc, PreTrainedTokenizerBase):
|
| 177 |
+
# add_special_tokens=False is crucial here to avoid double counting
|
| 178 |
+
# or inserting BOS/EOS in the middle of text during length checks
|
| 179 |
+
toks = enc.encode(text, add_special_tokens=False)
|
| 180 |
+
# Handle tiktoken
|
| 181 |
+
else:
|
| 182 |
+
toks = enc.encode(text, disallowed_special=())
|
| 183 |
+
|
| 184 |
+
if len(toks) <= max_tokens:
|
| 185 |
+
return text
|
| 186 |
+
|
| 187 |
+
# Keep the tail (usually the most relevant for instructions / recent context)
|
| 188 |
+
toks = toks[-max_tokens:]
|
| 189 |
+
|
| 190 |
+
return enc.decode(toks)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def truncate_chat_prompt(tokenizer, messages, max_context, max_output, overhead=256):
|
| 194 |
+
# Apply the model's chat template so token counting matches the server
|
| 195 |
+
prompt_text = tokenizer.apply_chat_template(
|
| 196 |
+
messages,
|
| 197 |
+
tokenize=False,
|
| 198 |
+
add_generation_prompt=True,
|
| 199 |
+
)
|
| 200 |
+
input_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids
|
| 201 |
+
|
| 202 |
+
budget = max_context - max_output - overhead
|
| 203 |
+
if budget < 0:
|
| 204 |
+
raise ValueError("max_output_tokens + overhead exceeds max_context_tokens")
|
| 205 |
+
|
| 206 |
+
if len(input_ids) > budget:
|
| 207 |
+
input_ids = input_ids[-budget:] # or keep the *start* depending on your needs
|
| 208 |
+
prompt_text = tokenizer.decode(input_ids, skip_special_tokens=False)
|
| 209 |
+
|
| 210 |
+
return prompt_text
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def llm_call(deployment_name: str,
|
| 214 |
+
api_version: str,
|
| 215 |
+
_prompt: str,
|
| 216 |
+
max_context_tokens: int = MAX_CONTEXT_TOKENS,
|
| 217 |
+
max_output_tokens: int = 1024,
|
| 218 |
+
extra_overhead_tokens: int = 32,
|
| 219 |
+
debug: bool = False,
|
| 220 |
+
vllm: bool = False,
|
| 221 |
+
tritonai: bool = False,
|
| 222 |
+
nvidia: bool = False):
|
| 223 |
+
if nvidia:
|
| 224 |
+
client = OpenAI(
|
| 225 |
+
api_key=os.getenv("NV_API_KEY"),
|
| 226 |
+
base_url="https://inference-api.nvidia.com/v1",
|
| 227 |
+
)
|
| 228 |
+
elif tritonai:
|
| 229 |
+
client = OpenAI(
|
| 230 |
+
api_key=os.getenv("TRITONAI_API_KEY"),
|
| 231 |
+
base_url=TRITONAI_BASE_URL,
|
| 232 |
+
)
|
| 233 |
+
max_context_tokens = 131_072 # DeepSeek R1 128K context
|
| 234 |
+
# DeepSeek R1 uses thinking tokens before the answer; raise output budget
|
| 235 |
+
if max_output_tokens < 4096:
|
| 236 |
+
max_output_tokens = 4096
|
| 237 |
+
elif vllm:
|
| 238 |
+
# Use local vLLM OpenAI-compatible server
|
| 239 |
+
vllm_base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
|
| 240 |
+
vllm_api_key = os.getenv("VLLM_API_KEY", "EMPTY")
|
| 241 |
+
client = OpenAI(
|
| 242 |
+
base_url=vllm_base_url,
|
| 243 |
+
api_key=vllm_api_key,
|
| 244 |
+
)
|
| 245 |
+
# Override deployment_name with the vLLM-served model name when set
|
| 246 |
+
# (needed when main model is LiteLLM proxy but reading uses local vLLM)
|
| 247 |
+
deployment_name = os.getenv("VLLM_MODEL_NAME", deployment_name)
|
| 248 |
+
# vLLM Qwen3-30B-A3B-Instruct-2507 has 131,072-token context
|
| 249 |
+
max_context_tokens = 131_072
|
| 250 |
+
elif debug:
|
| 251 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 252 |
+
else:
|
| 253 |
+
client = AzureOpenAI(
|
| 254 |
+
azure_endpoint=endpoint,
|
| 255 |
+
azure_ad_token_provider=credential,
|
| 256 |
+
api_version=api_version,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
enc = _get_encoder(deployment_name)
|
| 260 |
+
|
| 261 |
+
# How many tokens we can spend on the input
|
| 262 |
+
budget = max_context_tokens - max_output_tokens - extra_overhead_tokens
|
| 263 |
+
if budget < 0:
|
| 264 |
+
raise ValueError("max_output_tokens + overhead exceeds max_context_tokens")
|
| 265 |
+
|
| 266 |
+
prompt_truncated = _truncate_to_tokens(_prompt, enc, budget)
|
| 267 |
+
|
| 268 |
+
# Strip control characters and fix broken Unicode that break JSON serialization
|
| 269 |
+
if nvidia or tritonai:
|
| 270 |
+
prompt_truncated = prompt_truncated.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
|
| 271 |
+
prompt_truncated = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', prompt_truncated)
|
| 272 |
+
# Verify it's valid JSON-serializable
|
| 273 |
+
json.dumps(prompt_truncated)
|
| 274 |
+
|
| 275 |
+
# OpenAI-compatible proxy requires at least one user message; Azure accepts system role
|
| 276 |
+
msg_role = "user" if (tritonai or nvidia) else "system"
|
| 277 |
+
kwargs = {
|
| 278 |
+
'model': deployment_name,
|
| 279 |
+
'messages':[
|
| 280 |
+
{"role": msg_role, "content": prompt_truncated}
|
| 281 |
+
]
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
while True:
|
| 285 |
+
try:
|
| 286 |
+
completion = client.chat.completions.create(**kwargs)
|
| 287 |
+
break
|
| 288 |
+
except Exception as e:
|
| 289 |
+
from openai import APITimeoutError as _APITimeoutError
|
| 290 |
+
if isinstance(e, _APITimeoutError):
|
| 291 |
+
print(f"[WARN] APITimeoutError from LLM; sleeping 30s then retrying...", flush=True)
|
| 292 |
+
time.sleep(30)
|
| 293 |
+
continue
|
| 294 |
+
st = _retryable_status(e)
|
| 295 |
+
# 404 from LiteLLM proxy/Bedrock is intermittent (model temporarily unavailable)
|
| 296 |
+
retryable = (429, 500, 503, 403) + ((404,) if tritonai else ())
|
| 297 |
+
if st in retryable:
|
| 298 |
+
print(
|
| 299 |
+
f"[WARN] q_idx={di} HTTP {st} from LLM; sleeping 60s then retrying...",
|
| 300 |
+
flush=True
|
| 301 |
+
)
|
| 302 |
+
time.sleep(60)
|
| 303 |
+
continue
|
| 304 |
+
# Non-retryable -> re-raise
|
| 305 |
+
print('One exception captured', repr(e), flush=True)
|
| 306 |
+
raise
|
| 307 |
+
|
| 308 |
+
#answer = (completion.choices[0].message.content or "").strip()
|
| 309 |
+
#return answer
|
| 310 |
+
return completion
|
| 311 |
+
|
| 312 |
+
def custom_to_iso8601(time_str):
|
| 313 |
+
"""
|
| 314 |
+
Convert '2023/04/10 (Mon) 23:07' to '2023-04-10T23:07:00'
|
| 315 |
+
"""
|
| 316 |
+
# Remove the weekday (e.g., "(Mon)")
|
| 317 |
+
clean = time_str.split('(')[0].strip() + ' ' + time_str.split(')')[-1].strip()
|
| 318 |
+
# Parse the cleaned string
|
| 319 |
+
dt = datetime.strptime(clean, "%Y/%m/%d %H:%M")
|
| 320 |
+
# Format as ISO 8601
|
| 321 |
+
return dt.isoformat()
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def evaluate_retrieval(recalled_docs, correct_docs, k=10):
|
| 325 |
+
#recalled_docs = set(corpus_ids[idx] for idx in rankings[:k])
|
| 326 |
+
recall_any = float(any(doc in recalled_docs for doc in correct_docs))
|
| 327 |
+
recall_all = float(all(doc in recalled_docs for doc in correct_docs))
|
| 328 |
+
return recall_any, recall_all
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def print_average_metrics(retrieval_metric_list):
|
| 332 |
+
metric_sums = defaultdict(float)
|
| 333 |
+
metric_counts = defaultdict(int)
|
| 334 |
+
|
| 335 |
+
for metric in retrieval_metric_list:
|
| 336 |
+
for k, v in metric.items():
|
| 337 |
+
metric_sums[k] += v
|
| 338 |
+
metric_counts[k] += 1
|
| 339 |
+
|
| 340 |
+
print("\t\t Average metrics:")
|
| 341 |
+
for k in sorted(metric_sums):
|
| 342 |
+
avg = metric_sums[k] / metric_counts[k]
|
| 343 |
+
print(f"\t\t {k}: {avg:.4f}")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# Load prompt template
|
| 347 |
+
prompt_path = "prompts/agentic_retrieval_prompt.txt"
|
| 348 |
+
with open(prompt_path, "r", encoding="utf-8") as f:
|
| 349 |
+
stg_prompt = f.read()
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class ChatHistory:
|
| 353 |
+
def __init__(self, data: Dict[str, Any] = None, sessions: List = None):
|
| 354 |
+
assert not (data is not None and sessions is not None), "ChatHistory: Only one of data or sessions may be provided."
|
| 355 |
+
|
| 356 |
+
if data is not None: # From raw data dict
|
| 357 |
+
self.raw_data = data
|
| 358 |
+
self.sessions = []
|
| 359 |
+
self._parse_sessions()
|
| 360 |
+
elif sessions is not None: # From provided sessions list
|
| 361 |
+
self.sessions = sessions
|
| 362 |
+
self.messages = []
|
| 363 |
+
for sess in self.sessions:
|
| 364 |
+
session_id = sess['session_id']
|
| 365 |
+
timestamp = sess['timestamp']
|
| 366 |
+
for turn_idx, msg in enumerate(sess['session']):
|
| 367 |
+
entry = {
|
| 368 |
+
"role": msg.get("role"),
|
| 369 |
+
"content": msg.get("content"),
|
| 370 |
+
"session_id": session_id,
|
| 371 |
+
"turn_index": turn_idx,
|
| 372 |
+
"timestamp": timestamp,
|
| 373 |
+
"iso_datetime": timestamp.isoformat(),
|
| 374 |
+
"has_answer": msg.get("has_answer", False)
|
| 375 |
+
}
|
| 376 |
+
self.messages.append(entry)
|
| 377 |
+
else:
|
| 378 |
+
self.sessions = []
|
| 379 |
+
self.messages = []
|
| 380 |
+
|
| 381 |
+
def get_contents(self, granularity='turn', _format='json') -> list:
|
| 382 |
+
if granularity == "turn":
|
| 383 |
+
if _format == "json":
|
| 384 |
+
return [json.dumps(msg) for msg in self.messages]
|
| 385 |
+
else:
|
| 386 |
+
return [msg['content'] for msg in self.messages]
|
| 387 |
+
else: # granularity == 'session'
|
| 388 |
+
if _format == "json":
|
| 389 |
+
return [json.dumps(session) for session in self.sessions]
|
| 390 |
+
else:
|
| 391 |
+
return [json.dumps({"role": session["role"], "content": session["content"]})
|
| 392 |
+
for session in self.sessions]
|
| 393 |
+
|
| 394 |
+
def to_prompt(self, granularity='session', _format="json"):
|
| 395 |
+
history_str = ""
|
| 396 |
+
for session in self.sessions:
|
| 397 |
+
sess_str = json.dumps([{"role": x["role"], "content": x["content"]} for x in session['session']])
|
| 398 |
+
history_str += f"Session Date: {session['session_date']}\nSession Content:\n{sess_str}\n"
|
| 399 |
+
return history_str
|
| 400 |
+
|
| 401 |
+
def get_session_ids(self):
|
| 402 |
+
return [s['session_id'] for s in self.sessions]
|
| 403 |
+
|
| 404 |
+
@staticmethod
|
| 405 |
+
def _parse_date(date_str: str) -> datetime:
|
| 406 |
+
# Convert '2023/04/10 (Mon) 17:50' to datetime
|
| 407 |
+
# Remove weekday in parentheses
|
| 408 |
+
date_part, time_part = date_str.split('(')[0].strip(), date_str.split(')')[-1].strip()
|
| 409 |
+
dt = datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M")
|
| 410 |
+
return dt
|
| 411 |
+
|
| 412 |
+
def _parse_sessions(self):
|
| 413 |
+
"""
|
| 414 |
+
Flattens sessions into a list of messages, each with ISO 8601 date, session ID, and turn index
|
| 415 |
+
"""
|
| 416 |
+
self.messages = []
|
| 417 |
+
for date_str, session_id, session, topic in zip(
|
| 418 |
+
self.raw_data['haystack_dates'],
|
| 419 |
+
self.raw_data['haystack_session_ids'],
|
| 420 |
+
self.raw_data['haystack_sessions'],
|
| 421 |
+
self.raw_data['haystack_topics']
|
| 422 |
+
):
|
| 423 |
+
timestamp = self._parse_date(date_str)
|
| 424 |
+
for turn_idx, msg in enumerate(session):
|
| 425 |
+
entry = {
|
| 426 |
+
"role": msg.get("role"),
|
| 427 |
+
"content": msg.get("content"),
|
| 428 |
+
"session_id": session_id,
|
| 429 |
+
"turn_index": turn_idx,
|
| 430 |
+
"timestamp": timestamp,
|
| 431 |
+
"iso_datetime": timestamp.isoformat(),
|
| 432 |
+
"session_date": date_str,
|
| 433 |
+
"has_answer": msg.get("has_answer", False)
|
| 434 |
+
}
|
| 435 |
+
self.messages.append(entry)
|
| 436 |
+
self.sessions.append({
|
| 437 |
+
"session_date": date_str,
|
| 438 |
+
"timestamp": timestamp,
|
| 439 |
+
"session_id": session_id,
|
| 440 |
+
"session": session,
|
| 441 |
+
"topic": topic,
|
| 442 |
+
})
|
| 443 |
+
# Optionally, sort by time (ascending)
|
| 444 |
+
#self.sessions.sort(key=lambda x: x["timestamp"])
|
| 445 |
+
#self.messages.sort(key=lambda x: x["timestamp"])
|
| 446 |
+
|
| 447 |
+
def __len__(self):
|
| 448 |
+
return len(self.sessions)
|
| 449 |
+
|
| 450 |
+
def __getitem__(self, idx) -> Dict[str, any]:
|
| 451 |
+
return self.sessions[idx] # Return session dict
|
| 452 |
+
|
| 453 |
+
def get_item_by_index(self, idx):
|
| 454 |
+
if isinstance(idx, range) or isinstance(idx, list):
|
| 455 |
+
max_idx = len(self.sessions)
|
| 456 |
+
valid_indices = [i for i in idx if 0 <= i < max_idx]
|
| 457 |
+
selected_sessions = [self.sessions[i] for i in valid_indices]
|
| 458 |
+
return ChatHistory(sessions=selected_sessions)
|
| 459 |
+
else:
|
| 460 |
+
raise ValueError("Input must be a list or range of indices.")
|
| 461 |
+
|
| 462 |
+
def get_item_by_session_ids(self, sess_set):
|
| 463 |
+
if not isinstance(sess_set, set):
|
| 464 |
+
sess_set = set(sess_set)
|
| 465 |
+
new_sessions = []
|
| 466 |
+
for sess in self.sessions:
|
| 467 |
+
if sess['session_id'] in sess_set:
|
| 468 |
+
new_sessions.append(sess)
|
| 469 |
+
|
| 470 |
+
return ChatHistory(sessions=new_sessions)
|
| 471 |
+
|
| 472 |
+
def get_item_by_ranked_session(self, sess_id_sorted):
|
| 473 |
+
new_sessions = []
|
| 474 |
+
for sess_id in sess_id_sorted:
|
| 475 |
+
for sess in self.sessions:
|
| 476 |
+
if sess['session_id'] in sess_id:
|
| 477 |
+
new_sessions.append(sess)
|
| 478 |
+
|
| 479 |
+
return ChatHistory(sessions=new_sessions)
|
| 480 |
+
|
| 481 |
+
def get_item_by_topics(self, topics):
|
| 482 |
+
new_sessions = []
|
| 483 |
+
new_sess_ids = set()
|
| 484 |
+
for sess in self.sessions:
|
| 485 |
+
for tp in sess['topic']:
|
| 486 |
+
if tp in topics and sess['session_id'] not in new_sess_ids:
|
| 487 |
+
new_sessions.append(sess)
|
| 488 |
+
new_sess_ids.add(sess['session_id'])
|
| 489 |
+
break
|
| 490 |
+
|
| 491 |
+
return ChatHistory(sessions=new_sessions)
|
| 492 |
+
|
| 493 |
+
def merge_rel_sess(self, new_sessions):
|
| 494 |
+
# Gather all current and new sessions in a dict keyed by session_id
|
| 495 |
+
all_sessions = {s["session_id"]: s for s in self.sessions}
|
| 496 |
+
|
| 497 |
+
# add if new
|
| 498 |
+
for s in new_sessions.sessions:
|
| 499 |
+
if s["session_id"] not in all_sessions:
|
| 500 |
+
all_sessions[s["session_id"]] = s
|
| 501 |
+
|
| 502 |
+
# Reconstruct raw_data for new ChatHistory
|
| 503 |
+
merged_raw_data = {
|
| 504 |
+
"haystack_dates": [s["session_date"] for k, s in all_sessions.items()],
|
| 505 |
+
"haystack_session_ids": [s["session_id"] for k, s in all_sessions.items()],
|
| 506 |
+
"haystack_sessions": [s["session"] for k, s in all_sessions.items()],
|
| 507 |
+
"haystack_topics": [s["topic"] for k, s in all_sessions.items()],
|
| 508 |
+
}
|
| 509 |
+
self.sessions = ChatHistory(merged_raw_data)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def generate_keywords(question: str, deployment_name, api_version, debug=False, vllm=None,
|
| 514 |
+
tritonai=False, nvidia=False) -> List[str]:
|
| 515 |
+
# read prompt from `keyword_search_prompt.txt` file
|
| 516 |
+
with open('prompts/keyword_search_prompt.txt') as f:
|
| 517 |
+
prompt_template = f.read()
|
| 518 |
+
|
| 519 |
+
prompt = prompt_template + question
|
| 520 |
+
# Call the LLM to generate keywords
|
| 521 |
+
completion = llm_call(
|
| 522 |
+
deployment_name,
|
| 523 |
+
api_version,
|
| 524 |
+
prompt,
|
| 525 |
+
debug=debug,
|
| 526 |
+
vllm=vllm,
|
| 527 |
+
tritonai=tritonai,
|
| 528 |
+
nvidia=nvidia,
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
response_content = (completion.choices[0].message.content or "").strip()
|
| 532 |
+
result = parse_json(response_content)
|
| 533 |
+
keywords = result["keywords"] if "keywords" in result else []
|
| 534 |
+
return keywords
|
| 535 |
+
|
| 536 |
+
def keyword_search(chat_history: ChatHistory, keywords: list):
|
| 537 |
+
print(f"\t\t** keyword search **: {keywords}")
|
| 538 |
+
# Gather all messages that match
|
| 539 |
+
start_time = time.time()
|
| 540 |
+
matched_msgs = [
|
| 541 |
+
msg for msg in chat_history.messages
|
| 542 |
+
if any(kw.lower() in (msg.get("content") or "").lower() for kw in keywords)
|
| 543 |
+
]
|
| 544 |
+
end_time = time.time()
|
| 545 |
+
execution_time = end_time - start_time
|
| 546 |
+
|
| 547 |
+
if matched_msgs:
|
| 548 |
+
new_sess_ids = set()
|
| 549 |
+
for msg in matched_msgs:
|
| 550 |
+
key = msg["session_id"]
|
| 551 |
+
new_sess_ids.add(key)
|
| 552 |
+
new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids)
|
| 553 |
+
else:
|
| 554 |
+
new_chat_history = ChatHistory()
|
| 555 |
+
|
| 556 |
+
return new_chat_history
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def is_turn_id(text):
|
| 560 |
+
pattern = r'_\d+$'
|
| 561 |
+
return bool(re.search(pattern, text))
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def embedding_search(chat_history: ChatHistory, qid: str, top_k: int = 50, exclude_sess=None):
|
| 565 |
+
print("\t\t** embedding based retrieval **")
|
| 566 |
+
new_sess_ids = []
|
| 567 |
+
|
| 568 |
+
curr_all_sess = set(chat_history.get_session_ids())
|
| 569 |
+
if exclude_sess:
|
| 570 |
+
curr_all_sess = curr_all_sess - set(exclude_sess)
|
| 571 |
+
|
| 572 |
+
for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
|
| 573 |
+
if item["corpus_id"] in valid_sess_set:
|
| 574 |
+
sid = item["corpus_id"]
|
| 575 |
+
else:
|
| 576 |
+
tokens = item["corpus_id"].split("_")
|
| 577 |
+
|
| 578 |
+
if "_turn" in item["corpus_id"]:
|
| 579 |
+
sid = item["corpus_id"].split("_turn")[0]
|
| 580 |
+
elif "_fact" in item["corpus_id"]:
|
| 581 |
+
sid = item["corpus_id"].split("_fact")[0]
|
| 582 |
+
elif "noans" in item["corpus_id"]:
|
| 583 |
+
sid = item["corpus_id"].replace("noans", "answer")
|
| 584 |
+
elif is_turn_id(item["corpus_id"]):
|
| 585 |
+
sid = "_".join(tokens[:-1]) # remove turn index
|
| 586 |
+
else:
|
| 587 |
+
sid = item["corpus_id"]
|
| 588 |
+
|
| 589 |
+
if sid not in valid_sess_set:
|
| 590 |
+
print(item["corpus_id"], sid)
|
| 591 |
+
|
| 592 |
+
assert sid in valid_sess_set
|
| 593 |
+
|
| 594 |
+
if sid in curr_all_sess:
|
| 595 |
+
new_sess_ids.append(sid)
|
| 596 |
+
if len(new_sess_ids) == top_k:
|
| 597 |
+
break
|
| 598 |
+
new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids)
|
| 599 |
+
return new_chat_history
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def filter_out_by_embedding(chat_history: ChatHistory, qid: str, top_k: int = 50):
|
| 603 |
+
print("\t\t** [filter_out] embedding based retrieval - loading existing results ...")
|
| 604 |
+
new_sess_ids = []
|
| 605 |
+
curr_all_sess = set(chat_history.get_session_ids())
|
| 606 |
+
for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
|
| 607 |
+
if item["corpus_id"] in valid_sess_set:
|
| 608 |
+
sid = item["corpus_id"]
|
| 609 |
+
else:
|
| 610 |
+
tokens = item["corpus_id"].split("_")
|
| 611 |
+
|
| 612 |
+
if "_turn" in item["corpus_id"]:
|
| 613 |
+
sid = item["corpus_id"].split("_turn")[0]
|
| 614 |
+
elif "_fact" in item["corpus_id"]:
|
| 615 |
+
sid = item["corpus_id"].split("_fact")[0]
|
| 616 |
+
elif "noans" in item["corpus_id"]:
|
| 617 |
+
sid = item["corpus_id"].replace("noans", "answer")
|
| 618 |
+
elif is_turn_id(item["corpus_id"]):
|
| 619 |
+
sid = "_".join(tokens[:-1]) # remove turn index
|
| 620 |
+
else:
|
| 621 |
+
sid = item["corpus_id"]
|
| 622 |
+
|
| 623 |
+
if sid not in valid_sess_set:
|
| 624 |
+
print(item["corpus_id"], sid)
|
| 625 |
+
|
| 626 |
+
assert sid in valid_sess_set
|
| 627 |
+
|
| 628 |
+
if sid in curr_all_sess:
|
| 629 |
+
new_sess_ids.append(sid)
|
| 630 |
+
if len(new_sess_ids) == top_k:
|
| 631 |
+
break
|
| 632 |
+
new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids)
|
| 633 |
+
return new_chat_history, 0.0
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def flat_embedding_top_k_ids(qid: str, haystack_sess_ids: List[str], top_k: int) -> List[str]:
|
| 637 |
+
"""
|
| 638 |
+
Pull the top_k session IDs from the global GTE retrieval cache (retrieved_data_dict),
|
| 639 |
+
constrained to the question's haystack. Mirrors embedding_search() but operates on
|
| 640 |
+
IDs only (no ChatHistory). Used by hier_union to widen the Stage-2 pool.
|
| 641 |
+
"""
|
| 642 |
+
haystack_set = set(haystack_sess_ids)
|
| 643 |
+
ids: List[str] = []
|
| 644 |
+
for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
|
| 645 |
+
cid = item["corpus_id"]
|
| 646 |
+
if cid in valid_sess_set:
|
| 647 |
+
sid = cid
|
| 648 |
+
elif "_turn" in cid:
|
| 649 |
+
sid = cid.split("_turn")[0]
|
| 650 |
+
elif "_fact" in cid:
|
| 651 |
+
sid = cid.split("_fact")[0]
|
| 652 |
+
elif "noans" in cid:
|
| 653 |
+
sid = cid.replace("noans", "answer")
|
| 654 |
+
elif is_turn_id(cid):
|
| 655 |
+
sid = "_".join(cid.split("_")[:-1])
|
| 656 |
+
else:
|
| 657 |
+
sid = cid
|
| 658 |
+
if sid in haystack_set and sid not in ids:
|
| 659 |
+
ids.append(sid)
|
| 660 |
+
if len(ids) == top_k:
|
| 661 |
+
break
|
| 662 |
+
return ids
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def semantic_embedding_search(
|
| 666 |
+
qid: str,
|
| 667 |
+
haystack_sess_ids: List[str],
|
| 668 |
+
semantic_retrieved_dict: dict,
|
| 669 |
+
top_k: int = 50,
|
| 670 |
+
) -> List[str]:
|
| 671 |
+
"""
|
| 672 |
+
Like embedding_search() but reads from the pre-computed semantic-gte retrieval cache.
|
| 673 |
+
Returns an ordered list of up to top_k session IDs from the haystack.
|
| 674 |
+
"""
|
| 675 |
+
print("\t\t** semantic embedding retrieval **")
|
| 676 |
+
haystack_set = set(haystack_sess_ids)
|
| 677 |
+
ranked_ids: List[str] = []
|
| 678 |
+
for item in semantic_retrieved_dict[qid]["retrieval_results"]["ranked_items"]:
|
| 679 |
+
sid = item["corpus_id"] # already session-level (no turn suffix)
|
| 680 |
+
if sid in haystack_set and sid not in ranked_ids:
|
| 681 |
+
ranked_ids.append(sid)
|
| 682 |
+
if len(ranked_ids) == top_k:
|
| 683 |
+
break
|
| 684 |
+
return ranked_ids
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def time_filter(chat_history: ChatHistory, start_date: str, end_date: str) -> ChatHistory:
|
| 688 |
+
# Returns all messages with timestamp in the ISO date range [start_date, end_date] (inclusive).
|
| 689 |
+
start_time = time.time()
|
| 690 |
+
try:
|
| 691 |
+
start = datetime.fromisoformat(start_date)
|
| 692 |
+
end = datetime.fromisoformat(end_date)
|
| 693 |
+
filtered_msgs = [msg for msg in chat_history.messages if start.date() <= msg["timestamp"].date() <= end.date()]
|
| 694 |
+
except Exception as e:
|
| 695 |
+
print("Converting date error: ", e)
|
| 696 |
+
filtered_msgs = []
|
| 697 |
+
end_time = time.time()
|
| 698 |
+
execution_time = end_time - start_time
|
| 699 |
+
|
| 700 |
+
if filtered_msgs:
|
| 701 |
+
new_sess_ids = set()
|
| 702 |
+
for msg in filtered_msgs:
|
| 703 |
+
key = msg["session_id"]
|
| 704 |
+
new_sess_ids.add(key)
|
| 705 |
+
new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids)
|
| 706 |
+
else:
|
| 707 |
+
new_chat_history = ChatHistory()
|
| 708 |
+
|
| 709 |
+
return new_chat_history
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
class RetrievalAgent:
|
| 713 |
+
def __init__(
|
| 714 |
+
self,
|
| 715 |
+
history: List[Dict],
|
| 716 |
+
topics: List[str],
|
| 717 |
+
user_profile: str = None,
|
| 718 |
+
debug: bool = False,
|
| 719 |
+
vllm: bool = False,
|
| 720 |
+
vllm_reading: bool = False,
|
| 721 |
+
tritonai: bool = False,
|
| 722 |
+
nvidia: bool = False,
|
| 723 |
+
n_chunks: int = 10,
|
| 724 |
+
topic_filter: bool = True,
|
| 725 |
+
no_time_filter: bool = False,
|
| 726 |
+
semantic_store: SemanticMemoryStore = None,
|
| 727 |
+
episodic_store: EpisodicMemoryStore = None,
|
| 728 |
+
hier_v2: bool = False,
|
| 729 |
+
hier_union: bool = False,
|
| 730 |
+
hier_union_flat_k: int = 20,
|
| 731 |
+
no_early_answer: bool = False,
|
| 732 |
+
):
|
| 733 |
+
self.chat_history = history
|
| 734 |
+
self.user_profile = user_profile
|
| 735 |
+
self.topics = topics
|
| 736 |
+
self.rel_sess = ChatHistory()
|
| 737 |
+
self.evidence = []
|
| 738 |
+
self.debug = debug
|
| 739 |
+
self.vllm = vllm
|
| 740 |
+
self.vllm_reading = vllm_reading # use vLLM only for _read_and_verify
|
| 741 |
+
self.tritonai = tritonai # use LiteLLM proxy for non-reading LLM calls
|
| 742 |
+
self.nvidia = nvidia # use NVIDIA inference API
|
| 743 |
+
self.no_time_filter = no_time_filter # skip time_filter steps in strategy
|
| 744 |
+
self.n_chunks = n_chunks
|
| 745 |
+
self.topic_filter = topic_filter
|
| 746 |
+
self.semantic_store = semantic_store
|
| 747 |
+
self.episodic_store = episodic_store
|
| 748 |
+
self.hier_v2 = hier_v2
|
| 749 |
+
self.hier_union = hier_union
|
| 750 |
+
self.hier_union_flat_k = hier_union_flat_k
|
| 751 |
+
self.no_early_answer = no_early_answer
|
| 752 |
+
self.token_budget = {
|
| 753 |
+
'planning': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
|
| 754 |
+
'verification_reading': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
|
| 755 |
+
'is_answerable': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
|
| 756 |
+
'final_answer': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
|
| 757 |
+
}
|
| 758 |
+
|
| 759 |
+
if self.user_profile:
|
| 760 |
+
with open('prompts/read_and_extract_prompt.txt') as f:
|
| 761 |
+
self.read_prompt_template = f.read()
|
| 762 |
+
else: # ablation: wo_profile
|
| 763 |
+
with open('prompts/agentic_retrieval_prompt_wo_profile.txt') as f:
|
| 764 |
+
self.read_prompt_template = f.read()
|
| 765 |
+
|
| 766 |
+
def _track_usage(self, component: str, completion) -> None:
|
| 767 |
+
"""Accumulate prompt/completion token counts for a named component."""
|
| 768 |
+
usage = getattr(completion, 'usage', None)
|
| 769 |
+
if usage is None:
|
| 770 |
+
return
|
| 771 |
+
self.token_budget[component]['prompt_tokens'] += getattr(usage, 'prompt_tokens', 0) or 0
|
| 772 |
+
self.token_budget[component]['completion_tokens'] += getattr(usage, 'completion_tokens', 0) or 0
|
| 773 |
+
self.token_budget[component]['n_calls'] += 1
|
| 774 |
+
|
| 775 |
+
def get_token_budget(self) -> dict:
|
| 776 |
+
"""Return token_budget with an added 'total' entry."""
|
| 777 |
+
total = {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}
|
| 778 |
+
for v in self.token_budget.values():
|
| 779 |
+
total['prompt_tokens'] += v['prompt_tokens']
|
| 780 |
+
total['completion_tokens'] += v['completion_tokens']
|
| 781 |
+
total['n_calls'] += v['n_calls']
|
| 782 |
+
return {**self.token_budget, 'total': total}
|
| 783 |
+
|
| 784 |
+
def is_answerable(self, question: str, question_date: str, retrieved_sess, evidence, model_info, context_str: str = None) -> bool:
|
| 785 |
+
context = ""
|
| 786 |
+
for k, v in evidence.items(): # key: "profile", "tags", "chat_clues"
|
| 787 |
+
for e in v:
|
| 788 |
+
context += f"{e}\n"
|
| 789 |
+
|
| 790 |
+
# context_str overrides retrieved_sess.to_prompt() (used in Stage 1 with semantic context)
|
| 791 |
+
sess_str = context_str if context_str is not None else retrieved_sess.to_prompt()
|
| 792 |
+
|
| 793 |
+
# Include user profile in the prompt when available
|
| 794 |
+
profile_section = ""
|
| 795 |
+
if self.user_profile:
|
| 796 |
+
profile_section = f"\nUser Profile:\n{self.user_profile}\n"
|
| 797 |
+
|
| 798 |
+
ia_prompt_prefix = f"""
|
| 799 |
+
You are a decision-making agent tasked with determining when sufficient information has been gathered to answer a user's question.
|
| 800 |
+
|
| 801 |
+
Your Task:
|
| 802 |
+
Analyze the provided question, current date, available memory context, and available evidence to make a binary decision: Answerable or Not answerable. If the information is not sufficient, explain what specific information is needed to provide hints for the next retrieval stage.
|
| 803 |
+
|
| 804 |
+
Question: {question}
|
| 805 |
+
Current Date: {question_date}
|
| 806 |
+
{profile_section}
|
| 807 |
+
Memory Context:
|
| 808 |
+
"""
|
| 809 |
+
|
| 810 |
+
output_str = """
|
| 811 |
+
Output (always JSON — choose fields per the rules above)
|
| 812 |
+
|
| 813 |
+
Case 1 — Answerable:
|
| 814 |
+
{
|
| 815 |
+
"is_answerable": true,
|
| 816 |
+
"answer": "<concise answer grounded strictly in the Evidence>"
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
Case 2 — Not answerable:
|
| 820 |
+
{
|
| 821 |
+
"is_answerable": false,
|
| 822 |
+
"info_needed": ["<specific missing detail 1>", "<specific missing detail 2>"]
|
| 823 |
+
}
|
| 824 |
+
"""
|
| 825 |
+
deployment_name, api_version = model_info
|
| 826 |
+
|
| 827 |
+
# ------------------------------------------------------------------
|
| 828 |
+
# Token-based truncation: keep the *end* of sess_str under budget
|
| 829 |
+
# ------------------------------------------------------------------
|
| 830 |
+
enc = _get_encoder(deployment_name)
|
| 831 |
+
|
| 832 |
+
# Model/context limits
|
| 833 |
+
if self.vllm or self.tritonai:
|
| 834 |
+
# Large context for vLLM / LiteLLM proxy models
|
| 835 |
+
model_max_ctx = 131_072
|
| 836 |
+
else:
|
| 837 |
+
model_max_ctx = MAX_CONTEXT_TOKENS
|
| 838 |
+
|
| 839 |
+
max_output_tokens = 1024
|
| 840 |
+
extra_overhead_tokens = 32
|
| 841 |
+
|
| 842 |
+
# Total budget available for input tokens
|
| 843 |
+
budget = model_max_ctx - max_output_tokens - extra_overhead_tokens
|
| 844 |
+
if budget <= 0:
|
| 845 |
+
raise ValueError(
|
| 846 |
+
f"max_output_tokens ({max_output_tokens}) + overhead "
|
| 847 |
+
f"({extra_overhead_tokens}) exceeds model_max_ctx ({model_max_ctx})"
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
# Token lengths of static pieces
|
| 851 |
+
prefix_tokens = enc.encode(ia_prompt_prefix, disallowed_special=())
|
| 852 |
+
output_tokens = enc.encode(output_str, disallowed_special=())
|
| 853 |
+
sess_tokens = enc.encode(sess_str, disallowed_special=())
|
| 854 |
+
|
| 855 |
+
# Budget for sess_str tokens
|
| 856 |
+
available_for_sess = budget - len(prefix_tokens) - len(output_tokens)
|
| 857 |
+
|
| 858 |
+
if available_for_sess <= 0:
|
| 859 |
+
# No room for history at all; drop it
|
| 860 |
+
truncated_sess_str = ""
|
| 861 |
+
else:
|
| 862 |
+
if len(sess_tokens) > available_for_sess:
|
| 863 |
+
# Keep the *last* available_for_sess tokens (drop oldest history)
|
| 864 |
+
truncated_sess_tokens = sess_tokens[-available_for_sess:]
|
| 865 |
+
truncated_sess_str = enc.decode(truncated_sess_tokens)
|
| 866 |
+
else:
|
| 867 |
+
truncated_sess_str = sess_str
|
| 868 |
+
|
| 869 |
+
ia_prompt = ia_prompt_prefix + truncated_sess_str + "\n"
|
| 870 |
+
|
| 871 |
+
# ------------------------------------------------------------------
|
| 872 |
+
# Call the LLM with the already-truncated prompt
|
| 873 |
+
# ------------------------------------------------------------------
|
| 874 |
+
|
| 875 |
+
completion = llm_call(
|
| 876 |
+
deployment_name,
|
| 877 |
+
api_version,
|
| 878 |
+
ia_prompt + output_str,
|
| 879 |
+
max_context_tokens=model_max_ctx, # matches what we used for budgeting
|
| 880 |
+
max_output_tokens=max_output_tokens,
|
| 881 |
+
extra_overhead_tokens=extra_overhead_tokens,
|
| 882 |
+
debug=self.debug,
|
| 883 |
+
vllm=self.vllm,
|
| 884 |
+
tritonai=self.tritonai,
|
| 885 |
+
nvidia=self.nvidia,
|
| 886 |
+
)
|
| 887 |
+
self._track_usage('is_answerable', completion)
|
| 888 |
+
response_content = (completion.choices[0].message.content or "").strip()
|
| 889 |
+
print(f"\t\t[Agent] is_answerable: {response_content}")
|
| 890 |
+
result = parse_json(response_content)
|
| 891 |
+
if not result:
|
| 892 |
+
print("[Warning] Empty or invalid JSON in is_answerable() response.")
|
| 893 |
+
return {"is_answerable": False, "info_needed": ["Parsing failed"]}
|
| 894 |
+
return result
|
| 895 |
+
|
| 896 |
+
def _read_and_verify(self, question: str, question_date: str, evidence: ChatHistory, n_chunks=10) -> ChatHistory:
|
| 897 |
+
# Read evidence and and select
|
| 898 |
+
relevant_indices = []
|
| 899 |
+
evidence_list = []
|
| 900 |
+
max_idx = len(evidence)
|
| 901 |
+
for j in range(0, len(evidence), n_chunks):
|
| 902 |
+
chunk_range = range(j, j+n_chunks)
|
| 903 |
+
valid_indices = [i for i in chunk_range if 0 <= i < max_idx]
|
| 904 |
+
cur_chunk = evidence.get_item_by_index(valid_indices)
|
| 905 |
+
cur_chunk_sess = [[{"role": m["role"], "content": m["content"]} for m in sess['session']]
|
| 906 |
+
for sess in cur_chunk.sessions]
|
| 907 |
+
cur_chunk_sess_date = [sess['session_date'] for sess in cur_chunk.sessions]
|
| 908 |
+
sess_input_str = "\n".join([
|
| 909 |
+
f"### Session Index: {i}\n### Session Date: {sess_date}\n\n{json.dumps(sess)}\n"
|
| 910 |
+
for i, (sess, sess_date) in enumerate(zip(cur_chunk_sess, cur_chunk_sess_date))
|
| 911 |
+
])
|
| 912 |
+
_prompt = self.read_prompt_template + f"## Question: {question}\n## Question Date: {question_date}\n## Session list:\n\n{sess_input_str}\nNow, identify **only the sessions strictly necessary to answer the question**."
|
| 913 |
+
completion = llm_call(
|
| 914 |
+
deployment_name,
|
| 915 |
+
api_version,
|
| 916 |
+
_prompt,
|
| 917 |
+
debug=self.debug,
|
| 918 |
+
vllm=self.vllm or self.vllm_reading,
|
| 919 |
+
nvidia=self.nvidia,
|
| 920 |
+
)
|
| 921 |
+
self._track_usage('verification_reading', completion)
|
| 922 |
+
response_content = (completion.choices[0].message.content or "").strip()
|
| 923 |
+
|
| 924 |
+
print(f"\t\t {valid_indices[0]}~{valid_indices[-1]}: response: {response_content.replace(chr(10), '')}")
|
| 925 |
+
try:
|
| 926 |
+
start_idx = response_content.rfind('{')
|
| 927 |
+
end_idx = response_content.rfind('}') + 1
|
| 928 |
+
json_block = response_content[start_idx:end_idx]
|
| 929 |
+
result = json.loads(json_block)
|
| 930 |
+
|
| 931 |
+
if "index" in result and result['index'] and 'evidence' and result:
|
| 932 |
+
relevant_indices.extend([j + idx for idx in result['index']])
|
| 933 |
+
evidence_list.extend(result['evidence'])
|
| 934 |
+
except Exception as e:
|
| 935 |
+
print(f"Error parsing LLM response: {e}")
|
| 936 |
+
|
| 937 |
+
if relevant_indices:
|
| 938 |
+
return evidence.get_item_by_index(relevant_indices), evidence_list
|
| 939 |
+
else:
|
| 940 |
+
return ChatHistory(), []
|
| 941 |
+
|
| 942 |
+
def _read_and_verify_with_cache(self, qid:str, pool):
|
| 943 |
+
relevant_sess_ids = []
|
| 944 |
+
|
| 945 |
+
for sess in pool.sessions:
|
| 946 |
+
sess_id = sess['session_id']
|
| 947 |
+
if sess_id in qid2rel_sess_ids[qid]:
|
| 948 |
+
relevant_sess_ids.append(sess_id)
|
| 949 |
+
|
| 950 |
+
if len(relevant_sess_ids) > 0:
|
| 951 |
+
return pool.get_item_by_session_ids(relevant_sess_ids), []
|
| 952 |
+
else:
|
| 953 |
+
return ChatHistory(), []
|
| 954 |
+
|
| 955 |
+
def _plan(self, query: str, query_date: str, attempt_record: list, model_info) -> str:
|
| 956 |
+
if self.user_profile:
|
| 957 |
+
template = """
|
| 958 |
+
### User profile: {user_profile}
|
| 959 |
+
### Chat history topics: {topics}
|
| 960 |
+
### User query: {query}
|
| 961 |
+
### User query date: {query_date}
|
| 962 |
+
### Previous attempts:
|
| 963 |
+
{strategies_info}
|
| 964 |
+
"""
|
| 965 |
+
else:
|
| 966 |
+
template = """
|
| 967 |
+
### Chat history topics: {topics}
|
| 968 |
+
### User query: {query}
|
| 969 |
+
### User query date: {query_date}
|
| 970 |
+
### Previous attempts:
|
| 971 |
+
{strategies_info}
|
| 972 |
+
"""
|
| 973 |
+
|
| 974 |
+
if attempt_record:
|
| 975 |
+
strategies_info = ""
|
| 976 |
+
for loop_num, entry in enumerate(attempt_record):
|
| 977 |
+
strategies_info += f"\nloop_iteration: {loop_num+1}\n"
|
| 978 |
+
strategies_info += "\n".join(entry.get('step_logs', []))
|
| 979 |
+
if 'n_retrieved_sess' in entry and 'evidence' in entry:
|
| 980 |
+
strategies_info += f"Retrieved {entry['n_retrieved_sess']} docs, observed_evidence: {entry['evidence']}"
|
| 981 |
+
if entry['n_retrieved_sess'] == 0:
|
| 982 |
+
strategies_info += f"Additional Instruction: Re-try without filter methods if the previous paln includes topics or time-filtering\n"
|
| 983 |
+
else:
|
| 984 |
+
strategies_info = "(No previous attempt exists)"
|
| 985 |
+
|
| 986 |
+
if self.user_profile:
|
| 987 |
+
prompt_filled = stg_prompt + template.format(
|
| 988 |
+
user_profile=self.user_profile,
|
| 989 |
+
topics=",".join(self.topics),
|
| 990 |
+
query=query,
|
| 991 |
+
query_date=query_date,
|
| 992 |
+
strategies_info=strategies_info)
|
| 993 |
+
else:
|
| 994 |
+
prompt_filled = stg_prompt + template.format(
|
| 995 |
+
topics=",".join(self.topics),
|
| 996 |
+
query=query,
|
| 997 |
+
query_date=query_date,
|
| 998 |
+
strategies_info=strategies_info)
|
| 999 |
+
|
| 1000 |
+
deployment_name, api_version = model_info
|
| 1001 |
+
completion = llm_call(
|
| 1002 |
+
deployment_name,
|
| 1003 |
+
api_version,
|
| 1004 |
+
prompt_filled,
|
| 1005 |
+
debug=self.debug,
|
| 1006 |
+
vllm=self.vllm,
|
| 1007 |
+
tritonai=self.tritonai,
|
| 1008 |
+
nvidia=self.nvidia,
|
| 1009 |
+
)
|
| 1010 |
+
self._track_usage('planning', completion)
|
| 1011 |
+
response_content = (completion.choices[0].message.content or "").strip()
|
| 1012 |
+
_plan = parse_json(response_content)
|
| 1013 |
+
if not _plan:
|
| 1014 |
+
print("[Warning] Failed to parse plan JSON — retrying once.")
|
| 1015 |
+
completion = llm_call(
|
| 1016 |
+
deployment_name,
|
| 1017 |
+
api_version,
|
| 1018 |
+
prompt_filled,
|
| 1019 |
+
debug=self.debug,
|
| 1020 |
+
vllm=self.vllm,
|
| 1021 |
+
tritonai=self.tritonai,
|
| 1022 |
+
nvidia=self.nvidia,
|
| 1023 |
+
)
|
| 1024 |
+
self._track_usage('planning', completion)
|
| 1025 |
+
response_content = (completion.choices[0].message.content or "").strip()
|
| 1026 |
+
_plan = parse_json(response_content)
|
| 1027 |
+
if not _plan:
|
| 1028 |
+
print("[Warning] Failed to parse plan JSON after retry — returning fallback plan.")
|
| 1029 |
+
_plan = {"answer": "none", "reason": "invalid JSON response", "topics": [], "strategy": []}
|
| 1030 |
+
return _plan
|
| 1031 |
+
|
| 1032 |
+
def _run_stage1(
|
| 1033 |
+
self,
|
| 1034 |
+
qid: str,
|
| 1035 |
+
question: str,
|
| 1036 |
+
question_date: str,
|
| 1037 |
+
top_k: int,
|
| 1038 |
+
model_info,
|
| 1039 |
+
haystack_sess_ids: List[str],
|
| 1040 |
+
date_lookup: Dict[str, str],
|
| 1041 |
+
semantic_ret_dict: dict,
|
| 1042 |
+
) -> dict:
|
| 1043 |
+
"""
|
| 1044 |
+
Stage 1: retrieve and evaluate using semantic memory only (summaries + facts).
|
| 1045 |
+
|
| 1046 |
+
Returns a dict with:
|
| 1047 |
+
is_answerable : bool
|
| 1048 |
+
answer : str | None (set when is_answerable is True)
|
| 1049 |
+
candidate_ids : list[str] (top-K session IDs from semantic retrieval)
|
| 1050 |
+
attempt_record : list
|
| 1051 |
+
"""
|
| 1052 |
+
print(f"\t[Stage 1] Semantic memory retrieval for qid={qid}")
|
| 1053 |
+
|
| 1054 |
+
# --- 1a. Semantic embedding search ---
|
| 1055 |
+
candidate_ids = semantic_embedding_search(
|
| 1056 |
+
qid, haystack_sess_ids, semantic_ret_dict, top_k=top_k
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
# hier_v2: skip plan / keyword / time_filter / is_answerable. Stage 1 is candidate-only.
|
| 1060 |
+
if self.hier_v2:
|
| 1061 |
+
print(f"\t[Stage 1 hier_v2] embedding-only candidates: {len(candidate_ids)}")
|
| 1062 |
+
return {
|
| 1063 |
+
"is_answerable": False,
|
| 1064 |
+
"answer": None,
|
| 1065 |
+
"candidate_ids": candidate_ids[:top_k],
|
| 1066 |
+
"attempt_record": [{
|
| 1067 |
+
"stage": "semantic_v2",
|
| 1068 |
+
"plan": {},
|
| 1069 |
+
"n_candidates": len(candidate_ids),
|
| 1070 |
+
"candidate_ids": candidate_ids[:top_k],
|
| 1071 |
+
}],
|
| 1072 |
+
}
|
| 1073 |
+
|
| 1074 |
+
# --- 1b. Plan for keywords / time filter (reuse existing _plan) ---
|
| 1075 |
+
plan = self._plan(question, question_date, [], model_info)
|
| 1076 |
+
print(json.dumps(plan, indent=4), flush=True)
|
| 1077 |
+
|
| 1078 |
+
# If the planner already has a direct answer, return it
|
| 1079 |
+
if "answer" in plan and plan["answer"].lower() != "none":
|
| 1080 |
+
return {
|
| 1081 |
+
"is_answerable": True,
|
| 1082 |
+
"answer": plan["answer"],
|
| 1083 |
+
"candidate_ids": candidate_ids,
|
| 1084 |
+
"attempt_record": [{"plan": plan, "stage": "semantic"}],
|
| 1085 |
+
}
|
| 1086 |
+
|
| 1087 |
+
# --- 1c. Semantic keyword search ---
|
| 1088 |
+
keyword_ids: List[str] = []
|
| 1089 |
+
for step in plan.get("strategy", []):
|
| 1090 |
+
if step.get("method") == "keyword":
|
| 1091 |
+
kws = step.get("keywords", [])
|
| 1092 |
+
matched = self.semantic_store.keyword_search(kws, haystack_sess_ids)
|
| 1093 |
+
print(f"\t\t** semantic keyword search **: {kws} -> {len(matched)} matches")
|
| 1094 |
+
keyword_ids.extend(sid for sid in matched if sid not in keyword_ids)
|
| 1095 |
+
|
| 1096 |
+
# --- 1d. Time filter on candidate set ---
|
| 1097 |
+
for step in plan.get("strategy", []):
|
| 1098 |
+
if self.no_time_filter:
|
| 1099 |
+
break # skip all time_filter steps
|
| 1100 |
+
if step.get("method") == "time_filter":
|
| 1101 |
+
if "time_range" not in step or len(step["time_range"]) != 2:
|
| 1102 |
+
continue
|
| 1103 |
+
start_str, end_str = step["time_range"]
|
| 1104 |
+
from datetime import datetime
|
| 1105 |
+
try:
|
| 1106 |
+
start_dt = datetime.fromisoformat(start_str)
|
| 1107 |
+
end_dt = datetime.fromisoformat(end_str)
|
| 1108 |
+
candidate_ids = [
|
| 1109 |
+
sid for sid in candidate_ids
|
| 1110 |
+
if sid in date_lookup and
|
| 1111 |
+
start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date()
|
| 1112 |
+
]
|
| 1113 |
+
keyword_ids = [
|
| 1114 |
+
sid for sid in keyword_ids
|
| 1115 |
+
if sid in date_lookup and
|
| 1116 |
+
start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date()
|
| 1117 |
+
]
|
| 1118 |
+
print(f"\t\t** semantic time_filter **: {start_str}..{end_str} -> "
|
| 1119 |
+
f"{len(candidate_ids)} embed, {len(keyword_ids)} keyword")
|
| 1120 |
+
except Exception as e:
|
| 1121 |
+
print(f"\t\t[WARN] time_filter parse error: {e}")
|
| 1122 |
+
|
| 1123 |
+
# --- 1e. Merge candidates (keyword union with embedding, preserve rank) ---
|
| 1124 |
+
all_candidate_ids: List[str] = list(candidate_ids)
|
| 1125 |
+
for sid in keyword_ids:
|
| 1126 |
+
if sid not in all_candidate_ids:
|
| 1127 |
+
all_candidate_ids.append(sid)
|
| 1128 |
+
|
| 1129 |
+
# Cap to top_k
|
| 1130 |
+
all_candidate_ids = all_candidate_ids[:top_k]
|
| 1131 |
+
|
| 1132 |
+
if not all_candidate_ids:
|
| 1133 |
+
print("\t[Stage 1] No candidates found in semantic memory.")
|
| 1134 |
+
return {
|
| 1135 |
+
"is_answerable": False,
|
| 1136 |
+
"answer": None,
|
| 1137 |
+
"candidate_ids": [],
|
| 1138 |
+
"attempt_record": [{"plan": plan, "stage": "semantic", "n_candidates": 0}],
|
| 1139 |
+
}
|
| 1140 |
+
|
| 1141 |
+
# --- 1f. Build semantic context string ---
|
| 1142 |
+
semantic_context_str = self.semantic_store.to_prompt(all_candidate_ids, date_lookup)
|
| 1143 |
+
print(f"\t[Stage 1] Built semantic context for {len(all_candidate_ids)} sessions "
|
| 1144 |
+
f"({len(semantic_context_str)} chars)")
|
| 1145 |
+
|
| 1146 |
+
# --- 1g. is_answerable check on semantic context ---
|
| 1147 |
+
accumulated_evidence = {"profile": [], "chat_clues": []}
|
| 1148 |
+
answerable_response = self.is_answerable(
|
| 1149 |
+
question, question_date,
|
| 1150 |
+
retrieved_sess=None,
|
| 1151 |
+
evidence=accumulated_evidence,
|
| 1152 |
+
model_info=model_info,
|
| 1153 |
+
context_str=semantic_context_str,
|
| 1154 |
+
)
|
| 1155 |
+
print(f"\t[Stage 1] is_answerable: {answerable_response}")
|
| 1156 |
+
|
| 1157 |
+
attempt_record = [{
|
| 1158 |
+
"stage": "semantic",
|
| 1159 |
+
"plan": plan,
|
| 1160 |
+
"n_candidates": len(all_candidate_ids),
|
| 1161 |
+
"candidate_ids": all_candidate_ids,
|
| 1162 |
+
"is_answerable": answerable_response.get("is_answerable", False),
|
| 1163 |
+
}]
|
| 1164 |
+
|
| 1165 |
+
if answerable_response.get("is_answerable"):
|
| 1166 |
+
return {
|
| 1167 |
+
"is_answerable": True,
|
| 1168 |
+
"answer": answerable_response.get("answer"),
|
| 1169 |
+
"candidate_ids": all_candidate_ids,
|
| 1170 |
+
"attempt_record": attempt_record,
|
| 1171 |
+
}
|
| 1172 |
+
|
| 1173 |
+
print(f"\t[Stage 1] Not answerable from semantic memory. "
|
| 1174 |
+
f"Info needed: {answerable_response.get('info_needed', [])}")
|
| 1175 |
+
return {
|
| 1176 |
+
"is_answerable": False,
|
| 1177 |
+
"answer": None,
|
| 1178 |
+
"candidate_ids": all_candidate_ids,
|
| 1179 |
+
"attempt_record": attempt_record,
|
| 1180 |
+
}
|
| 1181 |
+
|
| 1182 |
+
def run(self, qid:str, question: str, question_date: str, top_k: int, model_info, max_loops=3,
|
| 1183 |
+
semantic_ret_dict: dict = None, haystack_sess_ids: List[str] = None,
|
| 1184 |
+
date_lookup: Dict[str, str] = None, topic_lookup: Dict[str, List[str]] = None):
|
| 1185 |
+
accumulated_evidence = {"profile": [], "chat_clues": []}
|
| 1186 |
+
attempt_record = []
|
| 1187 |
+
loop_num = 0
|
| 1188 |
+
|
| 1189 |
+
# ----------------------------------------------------------------
|
| 1190 |
+
# Stage 1: Semantic memory — only runs when stores are provided
|
| 1191 |
+
# ----------------------------------------------------------------
|
| 1192 |
+
stage1_candidate_ids: List[str] = []
|
| 1193 |
+
if (self.semantic_store is not None
|
| 1194 |
+
and semantic_ret_dict is not None
|
| 1195 |
+
and haystack_sess_ids is not None):
|
| 1196 |
+
stage1_result = self._run_stage1(
|
| 1197 |
+
qid, question, question_date, top_k, model_info,
|
| 1198 |
+
haystack_sess_ids, date_lookup or {}, semantic_ret_dict,
|
| 1199 |
+
)
|
| 1200 |
+
attempt_record.extend(stage1_result["attempt_record"])
|
| 1201 |
+
stage1_candidate_ids = stage1_result["candidate_ids"]
|
| 1202 |
+
|
| 1203 |
+
if stage1_result["is_answerable"] and not self.no_early_answer:
|
| 1204 |
+
print(f"\t[Stage 1] Answered from semantic memory.")
|
| 1205 |
+
# Wrap answer in attempt_record format expected by caller
|
| 1206 |
+
if "answer" in stage1_result and stage1_result["answer"]:
|
| 1207 |
+
attempt_record[0]["plan"] = {
|
| 1208 |
+
**attempt_record[0].get("plan", {}),
|
| 1209 |
+
"answer": stage1_result["answer"],
|
| 1210 |
+
}
|
| 1211 |
+
return ChatHistory(), attempt_record
|
| 1212 |
+
if stage1_result["is_answerable"] and self.no_early_answer:
|
| 1213 |
+
print(f"\t[Stage 1] is_answerable=True but --no_early_answer set; proceeding to Stage-2.")
|
| 1214 |
+
|
| 1215 |
+
# hier_union: widen Stage-2 pool with flat-embedding top-K from the global GTE cache.
|
| 1216 |
+
# This makes hier a strict superset of flat by construction; targets the recall gap
|
| 1217 |
+
# (semantic-over-summary embeddings rank worse than full-session embeddings).
|
| 1218 |
+
if self.hier_union and qid in retrieved_data_dict:
|
| 1219 |
+
flat_ids = flat_embedding_top_k_ids(qid, haystack_sess_ids, self.hier_union_flat_k)
|
| 1220 |
+
before = len(stage1_candidate_ids)
|
| 1221 |
+
for sid in flat_ids:
|
| 1222 |
+
if sid not in stage1_candidate_ids:
|
| 1223 |
+
stage1_candidate_ids.append(sid)
|
| 1224 |
+
print(f"\t[hier_union] semantic_top_k={before} + flat_top_{self.hier_union_flat_k}={len(flat_ids)} -> union={len(stage1_candidate_ids)}")
|
| 1225 |
+
attempt_record.append({
|
| 1226 |
+
"stage": "hier_union",
|
| 1227 |
+
"plan": {},
|
| 1228 |
+
"n_semantic": before,
|
| 1229 |
+
"n_flat": len(flat_ids),
|
| 1230 |
+
"n_union": len(stage1_candidate_ids),
|
| 1231 |
+
})
|
| 1232 |
+
|
| 1233 |
+
# Stage 2: load episodic sessions only for top-K candidates
|
| 1234 |
+
if self.episodic_store is not None and stage1_candidate_ids:
|
| 1235 |
+
print(f"\t[Stage 2] Loading episodic memory for "
|
| 1236 |
+
f"{len(stage1_candidate_ids)} candidate sessions.")
|
| 1237 |
+
raw_sessions = self.episodic_store.get_raw_sessions(
|
| 1238 |
+
stage1_candidate_ids, date_lookup or {}, topic_lookup
|
| 1239 |
+
)
|
| 1240 |
+
self.chat_history = ChatHistory(sessions=raw_sessions)
|
| 1241 |
+
print(f"\t[Stage 2] Loaded {len(self.chat_history)} sessions into episodic pool.")
|
| 1242 |
+
|
| 1243 |
+
# hier_v2: skip the agent loop. Use raw turns of the candidate sessions directly.
|
| 1244 |
+
# Rationale: the agent loop's verification can reject all of a 20-session pool,
|
| 1245 |
+
# leaving empty retrieved. Strong models answer better from raw turns of K=20
|
| 1246 |
+
# semantic-selected sessions than from over-aggressive verification.
|
| 1247 |
+
if self.hier_v2:
|
| 1248 |
+
print(f"\t[hier_v2] bypassing agent loop; returning {len(self.chat_history)} candidate sessions as retrieved")
|
| 1249 |
+
return self.chat_history, attempt_record
|
| 1250 |
+
|
| 1251 |
+
pool = self.chat_history
|
| 1252 |
+
retrieved = ChatHistory()
|
| 1253 |
+
|
| 1254 |
+
while loop_num < max_loops:
|
| 1255 |
+
loop_num += 1
|
| 1256 |
+
if loop_num == 1 and qid in qid2plan:
|
| 1257 |
+
plan = qid2plan[qid]
|
| 1258 |
+
else:
|
| 1259 |
+
plan = self._plan(question, question_date, attempt_record, model_info)
|
| 1260 |
+
qid2plan[qid] = plan
|
| 1261 |
+
print(json.dumps(plan, indent=4), flush=True)
|
| 1262 |
+
|
| 1263 |
+
if "answer" in plan and not ("none" in plan["answer"].lower()):
|
| 1264 |
+
print(f"{qid}\t{question}\t{plan['answer']}", flush=True)
|
| 1265 |
+
return ChatHistory(), [{"plan": plan}]
|
| 1266 |
+
|
| 1267 |
+
# 2) Execute the plan -> retrieve candidates
|
| 1268 |
+
try:
|
| 1269 |
+
candidates, step_logs = self._execute_strategy(pool, plan, question)
|
| 1270 |
+
except Exception as e:
|
| 1271 |
+
print(f"[Error] Failed during _execute_strategy: {e}")
|
| 1272 |
+
candidates, step_logs = ChatHistory(), [f"Execution failed: {e}"]
|
| 1273 |
+
|
| 1274 |
+
if len(candidates) == 0:
|
| 1275 |
+
attempt_record.append({
|
| 1276 |
+
"loop_iteration": loop_num,
|
| 1277 |
+
"plan": plan,
|
| 1278 |
+
"evidence": accumulated_evidence,
|
| 1279 |
+
"n_candidates_sess": len(candidates),
|
| 1280 |
+
"n_verified_sess": 0,
|
| 1281 |
+
"n_pool": len(pool),
|
| 1282 |
+
"step_logs": step_logs,
|
| 1283 |
+
})
|
| 1284 |
+
continue
|
| 1285 |
+
else:
|
| 1286 |
+
remaining = set(pool.get_session_ids()) - set(candidates.get_session_ids())
|
| 1287 |
+
pool = self.chat_history.get_item_by_session_ids(remaining)
|
| 1288 |
+
|
| 1289 |
+
# 3) Verification Reading
|
| 1290 |
+
if qid in qid2rel_sess_ids:
|
| 1291 |
+
verified, evidence_list = self._read_and_verify_with_cache(qid, candidates)
|
| 1292 |
+
else:
|
| 1293 |
+
verified, evidence_list = self._read_and_verify(question, question_date, candidates, n_chunks=self.n_chunks)
|
| 1294 |
+
qid2rel_sess_ids[qid] = verified.get_session_ids()
|
| 1295 |
+
|
| 1296 |
+
if len(verified) == 0:
|
| 1297 |
+
attempt_record.append({
|
| 1298 |
+
"loop_iteration": loop_num,
|
| 1299 |
+
"plan": plan,
|
| 1300 |
+
"evidence": accumulated_evidence,
|
| 1301 |
+
"n_candidates_sess": len(candidates),
|
| 1302 |
+
"candidates_sess_ids": candidates.get_session_ids(),
|
| 1303 |
+
"n_verified_sess": len(verified),
|
| 1304 |
+
"verified_sess_ids": [],
|
| 1305 |
+
"n_pool": len(pool),
|
| 1306 |
+
"step_logs": step_logs,
|
| 1307 |
+
})
|
| 1308 |
+
continue
|
| 1309 |
+
else:
|
| 1310 |
+
retrieved.merge_rel_sess(verified)
|
| 1311 |
+
|
| 1312 |
+
for ev in evidence_list:
|
| 1313 |
+
if ev not in accumulated_evidence['chat_clues']:
|
| 1314 |
+
accumulated_evidence['chat_clues'].append(ev)
|
| 1315 |
+
|
| 1316 |
+
attempt_record.append({
|
| 1317 |
+
"loop_iteration": loop_num,
|
| 1318 |
+
"plan": plan,
|
| 1319 |
+
"evidence": accumulated_evidence,
|
| 1320 |
+
"n_candidates_sess": len(candidates),
|
| 1321 |
+
"candidates_sess_ids": candidates.get_session_ids(),
|
| 1322 |
+
"n_verified_sess": len(verified),
|
| 1323 |
+
"verified_sess_ids": verified.get_session_ids(),
|
| 1324 |
+
"n_retrieved_sess": len(retrieved),
|
| 1325 |
+
"retrieved_sess_ids": retrieved.get_session_ids(),
|
| 1326 |
+
"n_pool": len(pool),
|
| 1327 |
+
"step_logs": step_logs,
|
| 1328 |
+
})
|
| 1329 |
+
|
| 1330 |
+
# 4) Decide if continue or not
|
| 1331 |
+
if len(retrieved) > top_k:
|
| 1332 |
+
retrieved, _ = filter_out_by_embedding(retrieved, qid=qid, top_k=top_k)
|
| 1333 |
+
|
| 1334 |
+
answerable_response = self.is_answerable(question, question_date, retrieved, accumulated_evidence, model_info)
|
| 1335 |
+
if answerable_response["is_answerable"]:
|
| 1336 |
+
plan["answer"] = answerable_response["answer"]
|
| 1337 |
+
return retrieved, attempt_record
|
| 1338 |
+
|
| 1339 |
+
if len(pool) == 0:
|
| 1340 |
+
break
|
| 1341 |
+
|
| 1342 |
+
return retrieved, attempt_record
|
| 1343 |
+
|
| 1344 |
+
def _execute_strategy(self, pool, plan, question):
|
| 1345 |
+
step_logs: List[str] = []
|
| 1346 |
+
|
| 1347 |
+
# Start from all chat items
|
| 1348 |
+
if self.topic_filter and len(plan.get('topics', [])) > 0:
|
| 1349 |
+
pool = pool.get_item_by_topics(plan['topics'])
|
| 1350 |
+
|
| 1351 |
+
retrieved = ChatHistory()
|
| 1352 |
+
|
| 1353 |
+
strategy = plan["strategy"]
|
| 1354 |
+
for step in strategy:
|
| 1355 |
+
method = step.get("method")
|
| 1356 |
+
if method == "keyword":
|
| 1357 |
+
kws = step.get("keywords", [])
|
| 1358 |
+
matched = keyword_search(pool, kws)
|
| 1359 |
+
step_logs.append(f"Method: keyword - matched {len(matched)}/{len(pool)} using {kws}")
|
| 1360 |
+
if len(matched) > 0:
|
| 1361 |
+
retrieved.merge_rel_sess(matched)
|
| 1362 |
+
elif method == "embedding":
|
| 1363 |
+
top_k = 50
|
| 1364 |
+
matched = embedding_search(pool, qid, top_k=top_k)
|
| 1365 |
+
step_logs.append(f"Method: embedding - top_k={top_k}, matched {len(matched)}/{len(pool)}")
|
| 1366 |
+
if len(matched) > 0:
|
| 1367 |
+
retrieved.merge_rel_sess(matched)
|
| 1368 |
+
elif method == "time_filter":
|
| 1369 |
+
if self.no_time_filter:
|
| 1370 |
+
step_logs.append(f"Method: time_filter - skipped (--no_time_filter)")
|
| 1371 |
+
continue
|
| 1372 |
+
if 'time_range' not in step or len(step['time_range']) != 2:
|
| 1373 |
+
continue
|
| 1374 |
+
if len(retrieved) > 0:
|
| 1375 |
+
retrieved = time_filter(retrieved, start_date=step['time_range'][0], end_date=step['time_range'][1])
|
| 1376 |
+
else:
|
| 1377 |
+
retrieved = time_filter(pool, start_date=step['time_range'][0], end_date=step['time_range'][1])
|
| 1378 |
+
step_logs.append(f"Method: time_filter - kept {len(retrieved)}/{len(pool)} in {step['time_range'][0]}..{step['time_range'][1]}")
|
| 1379 |
+
#if len(matched) > 0:
|
| 1380 |
+
# retrieved.merge_rel_sess(matched)
|
| 1381 |
+
else:
|
| 1382 |
+
step_logs.append(f"unknown method: {method}")
|
| 1383 |
+
|
| 1384 |
+
if len(retrieved) > 100:
|
| 1385 |
+
top_k = 100
|
| 1386 |
+
retrieved = embedding_search(retrieved, qid, top_k=top_k)
|
| 1387 |
+
step_logs.append(
|
| 1388 |
+
f"too many sess ({len(pool)}) - embedding top_k={top_k} matched {len(retrieved)}/{len(pool)}"
|
| 1389 |
+
)
|
| 1390 |
+
|
| 1391 |
+
return retrieved, step_logs
|
| 1392 |
+
|
| 1393 |
+
def merge_rel_sess(self, new_sessions: ChatHistory):
|
| 1394 |
+
# Gather all current and new sessions in a dict keyed by session_id
|
| 1395 |
+
all_sessions = {s["session_id"]: s for s in self.rel_sess.sessions}
|
| 1396 |
+
|
| 1397 |
+
# add if new
|
| 1398 |
+
for s in new_sessions.sessions:
|
| 1399 |
+
if s["session_id"] not in all_sessions:
|
| 1400 |
+
all_sessions[s["session_id"]] = s
|
| 1401 |
+
|
| 1402 |
+
# Optional: sort sessions by timestamp for consistent ordering
|
| 1403 |
+
#merged_sessions = list(all_sessions.values())
|
| 1404 |
+
#merged_sessions.sort(key=lambda x: x["timestamp"])
|
| 1405 |
+
|
| 1406 |
+
# Reconstruct raw_data for new ChatHistory
|
| 1407 |
+
merged_raw_data = {
|
| 1408 |
+
"haystack_dates": [s["session_date"] for k, s in all_sessions.items()],
|
| 1409 |
+
"haystack_session_ids": [s["session_id"] for k, s in all_sessions.items()],
|
| 1410 |
+
"haystack_sessions": [s["session"] for k, s in all_sessions.items()],
|
| 1411 |
+
}
|
| 1412 |
+
self.rel_sess = ChatHistory(merged_raw_data)
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
def merge_evidence(self, new_evidence: list):
|
| 1416 |
+
self.evidence = self.evidence + new_evidence
|
| 1417 |
+
print(f"\t\t Updated evidence: {self.evidence}")
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
if __name__ == "__main__":
|
| 1421 |
+
parser = argparse.ArgumentParser()
|
| 1422 |
+
parser.add_argument('--in_file', type=str, required=True)
|
| 1423 |
+
parser.add_argument('--out_file', type=str, required=True)
|
| 1424 |
+
parser.add_argument('--model_name', type=str, required=True)
|
| 1425 |
+
parser.add_argument('--top_k', type=int, required=True)
|
| 1426 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
| 1427 |
+
parser.add_argument('--vllm', action='store_true', default=False)
|
| 1428 |
+
parser.add_argument('--tritonai', action='store_true', default=False,
|
| 1429 |
+
help='Use OpenAI-compatible LiteLLM proxy for non-reading LLM calls (set TRITONAI_API_KEY)')
|
| 1430 |
+
parser.add_argument('--nvidia', action='store_true', default=False,
|
| 1431 |
+
help='Use NVIDIA inference API (set NV_API_KEY)')
|
| 1432 |
+
parser.add_argument('--vllm_reading', action='store_true', default=False,
|
| 1433 |
+
help='Use vLLM only for verification reading; all other LLM calls use the proprietary API')
|
| 1434 |
+
parser.add_argument('--n_chunks', type=int, default=10,
|
| 1435 |
+
help='Number of sessions per verification-reading LLM call (default 10)')
|
| 1436 |
+
parser.add_argument('--max_loops', type=int, default=3,
|
| 1437 |
+
help='Maximum retrieval/planning loops for agent mode (default 3)')
|
| 1438 |
+
parser.add_argument('--mode', type=str, default="agent", choices=['agent', 'embed', 'keyword'])
|
| 1439 |
+
parser.add_argument('--topic_filter', type=bool, default=True)
|
| 1440 |
+
parser.add_argument('--user_profile', action=argparse.BooleanOptionalAction, default=True,
|
| 1441 |
+
help='Include user profile in prompts (default: True, use --no-user_profile to disable)')
|
| 1442 |
+
parser.add_argument('--no_semantic', action='store_true', default=False,
|
| 1443 |
+
help='Skip semantic Stage 1; run episodic Stage 2 only on all haystack sessions')
|
| 1444 |
+
parser.add_argument('--no_time_filter', action='store_true', default=False,
|
| 1445 |
+
help='Disable time_filter steps in strategy execution (can reuse plan cache)')
|
| 1446 |
+
# Two-stage memory arguments
|
| 1447 |
+
parser.add_argument('--semantic_ret_cache', type=str, default=None,
|
| 1448 |
+
help='Path to semantic-gte retrieval log (JSONL) for Stage 1')
|
| 1449 |
+
parser.add_argument('--summary_file', type=str, default=None,
|
| 1450 |
+
help='Path to all_session_summary.json for SemanticMemoryStore')
|
| 1451 |
+
parser.add_argument('--facts_file', type=str, default=None,
|
| 1452 |
+
help='Path to all_session_user_facts.json for SemanticMemoryStore')
|
| 1453 |
+
parser.add_argument('--all_sessions_file', type=str, default=None,
|
| 1454 |
+
help='Path to all_sessions.json for lazy episodic loading')
|
| 1455 |
+
parser.add_argument('--no_save_cache', action='store_true', default=False,
|
| 1456 |
+
help='Disable saving plan/reading caches to disk after the run')
|
| 1457 |
+
parser.add_argument('--hier_v2', action='store_true', default=False,
|
| 1458 |
+
help='Stage 1 produces candidates only: skip early-answer return, semantic keyword expansion, time_filter, and is_answerable shortcut')
|
| 1459 |
+
parser.add_argument('--hier_union', action='store_true', default=False,
|
| 1460 |
+
help='hier mode: union Stage-1 semantic candidates with flat-embedding top-K and run agent loop on the merged pool')
|
| 1461 |
+
parser.add_argument('--hier_union_flat_k', type=int, default=20,
|
| 1462 |
+
help='How many flat-embedding top-K IDs to union into the Stage-2 pool (default 20)')
|
| 1463 |
+
parser.add_argument('--no_early_answer', action='store_true', default=False,
|
| 1464 |
+
help='Disable Stage-1 is_answerable early-return shortcut; always proceed to Stage-2 agent loop')
|
| 1465 |
+
parser.add_argument('--answer_prompt_v2', action='store_true', default=False,
|
| 1466 |
+
help='Use the v2 answer prompt with explicit guidance for aggregation, temporal reasoning, knowledge updates, and absence cases.')
|
| 1467 |
+
args = parser.parse_args()
|
| 1468 |
+
|
| 1469 |
+
# Rebind reading cache to include n_chunks so different chunk sizes get separate caches
|
| 1470 |
+
veri_reading_log_file = os.environ['reading_cache'] + f'_nchunks{args.n_chunks}'
|
| 1471 |
+
qid2rel_sess_ids = {}
|
| 1472 |
+
if os.path.exists(veri_reading_log_file):
|
| 1473 |
+
qid2rel_sess_ids = json.load(open(veri_reading_log_file))
|
| 1474 |
+
print(f'Reading cache: {veri_reading_log_file} ({len(qid2rel_sess_ids)} cached entries)')
|
| 1475 |
+
|
| 1476 |
+
in_data = json.load(open(args.in_file))
|
| 1477 |
+
top_k = args.top_k
|
| 1478 |
+
out_file = args.out_file
|
| 1479 |
+
|
| 1480 |
+
model_info = model_zoo[args.model_name]
|
| 1481 |
+
deployment_name, api_version = model_info
|
| 1482 |
+
|
| 1483 |
+
existings = set()
|
| 1484 |
+
retrieval_metric_list = []
|
| 1485 |
+
if os.path.exists(out_file):
|
| 1486 |
+
for line in open(out_file):
|
| 1487 |
+
obj = json.loads(line)
|
| 1488 |
+
existings.add(obj['question_id'])
|
| 1489 |
+
if 'retrieval_metric' in obj:
|
| 1490 |
+
retrieval_metric_list.append(obj['retrieval_metric'])
|
| 1491 |
+
|
| 1492 |
+
out_f = open(out_file, 'a')
|
| 1493 |
+
|
| 1494 |
+
############# read meta files #####################
|
| 1495 |
+
qid2profiles = {}
|
| 1496 |
+
with open("metadata/generated_user_profile.json") as f:
|
| 1497 |
+
qid2profiles = json.load(f)
|
| 1498 |
+
sess2topic = {}
|
| 1499 |
+
with open("metadata/sessions_with_topic.json") as f:
|
| 1500 |
+
sess2topic = json.load(f)
|
| 1501 |
+
|
| 1502 |
+
# ----------------------------------------------------------------
|
| 1503 |
+
# Two-stage memory stores (optional; activated by CLI args)
|
| 1504 |
+
# ----------------------------------------------------------------
|
| 1505 |
+
semantic_store = None
|
| 1506 |
+
episodic_store = None
|
| 1507 |
+
semantic_ret_dict = None
|
| 1508 |
+
|
| 1509 |
+
if args.summary_file and args.facts_file:
|
| 1510 |
+
semantic_store = SemanticMemoryStore(args.summary_file, args.facts_file)
|
| 1511 |
+
|
| 1512 |
+
if args.all_sessions_file:
|
| 1513 |
+
episodic_store = EpisodicMemoryStore(args.all_sessions_file)
|
| 1514 |
+
|
| 1515 |
+
if args.semantic_ret_cache:
|
| 1516 |
+
print(f"Loading semantic retrieval cache from {args.semantic_ret_cache} ...")
|
| 1517 |
+
sem_ret_data = [json.loads(line) for line in open(args.semantic_ret_cache)]
|
| 1518 |
+
semantic_ret_dict = {x['question_id']: x for x in sem_ret_data}
|
| 1519 |
+
print(f" Loaded {len(semantic_ret_dict)} entries.")
|
| 1520 |
+
|
| 1521 |
+
retrieval_metric_list = []
|
| 1522 |
+
for di, entry in enumerate(in_data):
|
| 1523 |
+
item_start_time = time.time()
|
| 1524 |
+
qid, question, q_date = entry['question_id'], entry['question'], entry['question_date']
|
| 1525 |
+
q_date = entry['question_date']
|
| 1526 |
+
|
| 1527 |
+
if qid in existings:
|
| 1528 |
+
continue
|
| 1529 |
+
|
| 1530 |
+
haystack_sess_ids = entry['haystack_session_ids']
|
| 1531 |
+
haystack_topics = [sess2topic.get(sid, {}).get('category', []) for sid in haystack_sess_ids]
|
| 1532 |
+
date_lookup = dict(zip(haystack_sess_ids, entry['haystack_dates']))
|
| 1533 |
+
topic_lookup = dict(zip(haystack_sess_ids, haystack_topics))
|
| 1534 |
+
|
| 1535 |
+
# Build ChatHistory: lazily from episodic store (two-stage) or from raw data (legacy)
|
| 1536 |
+
if episodic_store is not None:
|
| 1537 |
+
# Two-stage mode: start with full haystack loaded from episodic store
|
| 1538 |
+
# (Stage 1 will narrow this down before Stage 2 runs)
|
| 1539 |
+
raw_sessions = episodic_store.get_raw_sessions(
|
| 1540 |
+
haystack_sess_ids, date_lookup, topic_lookup
|
| 1541 |
+
)
|
| 1542 |
+
chat_history = ChatHistory(sessions=raw_sessions)
|
| 1543 |
+
else:
|
| 1544 |
+
chat_history = ChatHistory({
|
| 1545 |
+
"haystack_dates": entry['haystack_dates'],
|
| 1546 |
+
"haystack_session_ids": entry['haystack_session_ids'],
|
| 1547 |
+
"haystack_sessions": entry['haystack_sessions'],
|
| 1548 |
+
"haystack_topics": haystack_topics,
|
| 1549 |
+
})
|
| 1550 |
+
|
| 1551 |
+
topic_set = set()
|
| 1552 |
+
for ht in haystack_topics:
|
| 1553 |
+
topic_set.update(ht)
|
| 1554 |
+
|
| 1555 |
+
if args.user_profile:
|
| 1556 |
+
# user profile
|
| 1557 |
+
temp_qid = qid
|
| 1558 |
+
if '_q_' in qid:
|
| 1559 |
+
temp_qid = qid.split("_q_")[0]
|
| 1560 |
+
user_profile = qid2profiles[temp_qid]
|
| 1561 |
+
agent = RetrievalAgent(
|
| 1562 |
+
chat_history,
|
| 1563 |
+
list(topic_set),
|
| 1564 |
+
user_profile=user_profile,
|
| 1565 |
+
debug=args.debug,
|
| 1566 |
+
vllm=args.vllm,
|
| 1567 |
+
vllm_reading=args.vllm_reading,
|
| 1568 |
+
tritonai=args.tritonai,
|
| 1569 |
+
nvidia=args.nvidia,
|
| 1570 |
+
n_chunks=args.n_chunks,
|
| 1571 |
+
topic_filter=args.topic_filter,
|
| 1572 |
+
no_time_filter=args.no_time_filter,
|
| 1573 |
+
semantic_store=None if args.no_semantic else semantic_store,
|
| 1574 |
+
episodic_store=episodic_store,
|
| 1575 |
+
hier_v2=args.hier_v2,
|
| 1576 |
+
hier_union=args.hier_union,
|
| 1577 |
+
hier_union_flat_k=args.hier_union_flat_k,
|
| 1578 |
+
no_early_answer=args.no_early_answer,
|
| 1579 |
+
)
|
| 1580 |
+
else:
|
| 1581 |
+
agent = RetrievalAgent(
|
| 1582 |
+
chat_history,
|
| 1583 |
+
list(topic_set),
|
| 1584 |
+
debug=args.debug,
|
| 1585 |
+
vllm=args.vllm,
|
| 1586 |
+
vllm_reading=args.vllm_reading,
|
| 1587 |
+
tritonai=args.tritonai,
|
| 1588 |
+
nvidia=args.nvidia,
|
| 1589 |
+
n_chunks=args.n_chunks,
|
| 1590 |
+
topic_filter=args.topic_filter,
|
| 1591 |
+
no_time_filter=args.no_time_filter,
|
| 1592 |
+
semantic_store=None if args.no_semantic else semantic_store,
|
| 1593 |
+
episodic_store=episodic_store,
|
| 1594 |
+
hier_v2=args.hier_v2,
|
| 1595 |
+
hier_union=args.hier_union,
|
| 1596 |
+
hier_union_flat_k=args.hier_union_flat_k,
|
| 1597 |
+
no_early_answer=args.no_early_answer,
|
| 1598 |
+
)
|
| 1599 |
+
|
| 1600 |
+
try:
|
| 1601 |
+
if args.mode == 'embed':
|
| 1602 |
+
final_sess = embedding_search(chat_history, qid, top_k=top_k)
|
| 1603 |
+
attempt_record = [{"plan": {"answer": "none", "reason": "embedding retrieval only"}}]
|
| 1604 |
+
elif args.mode == 'keyword':
|
| 1605 |
+
keywords = generate_keywords(question, deployment_name, api_version,
|
| 1606 |
+
debug=args.debug, vllm=args.vllm,
|
| 1607 |
+
tritonai=args.tritonai, nvidia=args.nvidia)
|
| 1608 |
+
final_sess = keyword_search(chat_history, keywords=keywords)
|
| 1609 |
+
attempt_record = [{"plan": {"answer": "none", "reason": "keyword retrieval only"}}]
|
| 1610 |
+
else: # agent
|
| 1611 |
+
final_sess, attempt_record = agent.run(
|
| 1612 |
+
qid, question, q_date, top_k, model_info,
|
| 1613 |
+
max_loops=args.max_loops,
|
| 1614 |
+
semantic_ret_dict=semantic_ret_dict,
|
| 1615 |
+
haystack_sess_ids=haystack_sess_ids,
|
| 1616 |
+
date_lookup=date_lookup,
|
| 1617 |
+
topic_lookup=topic_lookup,
|
| 1618 |
+
)
|
| 1619 |
+
|
| 1620 |
+
if len(attempt_record) == 1 and "answer" in attempt_record[0]["plan"] and not ("none" in attempt_record[0]["plan"]["answer"].lower()):
|
| 1621 |
+
answer = attempt_record[0]["plan"]["answer"]
|
| 1622 |
+
token_budget = agent.get_token_budget() if args.mode == 'agent' else {}
|
| 1623 |
+
wall_time_sec = time.time() - item_start_time
|
| 1624 |
+
|
| 1625 |
+
print(json.dumps({"q_idx": di, 'question_id': qid, 'question': entry['question'],
|
| 1626 |
+
'answer': answer, 'n_retrieved': len(final_sess),
|
| 1627 |
+
'wall_time_sec': round(wall_time_sec, 3)}, indent=4), flush=True)
|
| 1628 |
+
print(json.dumps({"q_idx": di, 'question_id': qid,
|
| 1629 |
+
'hypothesis': answer,
|
| 1630 |
+
"attempt_record": attempt_record,
|
| 1631 |
+
"token_budget": token_budget,
|
| 1632 |
+
"wall_time_sec": wall_time_sec}), file=out_f, flush=True)
|
| 1633 |
+
else:
|
| 1634 |
+
if len(final_sess) > top_k and retrieved_log_file is not None:
|
| 1635 |
+
final_top_k_sess, _ = filter_out_by_embedding(final_sess, qid=qid, top_k=top_k)
|
| 1636 |
+
retrieved_str = final_top_k_sess.to_prompt(granularity="session", _format="json")
|
| 1637 |
+
else:
|
| 1638 |
+
retrieved_str = final_sess.to_prompt(granularity="session", _format="json")
|
| 1639 |
+
|
| 1640 |
+
if args.answer_prompt_v2:
|
| 1641 |
+
answer_prompt_template = (
|
| 1642 |
+
"You are answering a question using a list of chat-session transcripts between the user and an assistant.\n"
|
| 1643 |
+
"\n"
|
| 1644 |
+
"How to answer:\n"
|
| 1645 |
+
"1. Scan ALL retrieved sessions in chronological order. The SESSION DATE on each transcript is when that conversation occurred. The Current Date below is when the question was asked, not when events happened.\n"
|
| 1646 |
+
"2. Identify every session containing a candidate fact. If sessions conflict, prefer the most RECENT session that addresses the same fact (knowledge update).\n"
|
| 1647 |
+
"3. For aggregation questions ('how many', 'list all', 'between X and Y'), enumerate matches across ALL relevant sessions; do not stop at the first.\n"
|
| 1648 |
+
"4. For temporal queries ('last Friday', 'two weeks ago'), resolve the relative date against the SESSION DATE of the session that uses that phrase, not the Current Date.\n"
|
| 1649 |
+
"5. If the retrieved sessions do NOT contain the answer, reply exactly 'Insufficient information in retrieved sessions.' Do not fabricate.\n"
|
| 1650 |
+
"6. Be terse: state the direct answer first, then one short sentence citing the session date(s) you relied on.\n"
|
| 1651 |
+
"\n"
|
| 1652 |
+
"Chat history sessions:\n"
|
| 1653 |
+
"\n"
|
| 1654 |
+
"{}\n"
|
| 1655 |
+
"\n"
|
| 1656 |
+
"Current Date: {}\n"
|
| 1657 |
+
"Question: {}\n"
|
| 1658 |
+
"Answer:"
|
| 1659 |
+
)
|
| 1660 |
+
else:
|
| 1661 |
+
answer_prompt_template = "I will give you several chat history sessions between you and a user. Please answer the question given the information.\n\n\nChat history sessions:\n\n{}\n\nCurrent Date: {}\nQuestion: {}\nAnswer:"
|
| 1662 |
+
answer_prompt = answer_prompt_template.format(retrieved_str, entry['question_date'], entry['question'])
|
| 1663 |
+
|
| 1664 |
+
completion = llm_call(
|
| 1665 |
+
deployment_name,
|
| 1666 |
+
api_version,
|
| 1667 |
+
answer_prompt,
|
| 1668 |
+
debug=args.debug,
|
| 1669 |
+
vllm=args.vllm,
|
| 1670 |
+
tritonai=args.tritonai,
|
| 1671 |
+
nvidia=args.nvidia,
|
| 1672 |
+
)
|
| 1673 |
+
answer = (completion.choices[0].message.content or "").strip()
|
| 1674 |
+
|
| 1675 |
+
if args.mode == 'agent':
|
| 1676 |
+
agent._track_usage('final_answer', completion)
|
| 1677 |
+
token_budget = agent.get_token_budget() if args.mode == 'agent' else {}
|
| 1678 |
+
|
| 1679 |
+
retrieval_metric = {}
|
| 1680 |
+
if len(final_sess) > 0 and retrieved_log_file is not None:
|
| 1681 |
+
sess_sorted = embedding_search(final_sess, qid, top_k=20)
|
| 1682 |
+
sess_id_sorted = sess_sorted.get_session_ids()
|
| 1683 |
+
|
| 1684 |
+
for topk in [5, 10, 20, 30]:
|
| 1685 |
+
recall_any, recall_all = evaluate_retrieval(sess_id_sorted[:topk], entry['answer_session_ids'])
|
| 1686 |
+
retrieval_metric.update({
|
| 1687 |
+
'recall_any@{}'.format(topk): recall_any,
|
| 1688 |
+
'recall_all@{}'.format(topk): recall_all
|
| 1689 |
+
})
|
| 1690 |
+
retrieval_metric_list.append(retrieval_metric)
|
| 1691 |
+
print_average_metrics(retrieval_metric_list)
|
| 1692 |
+
|
| 1693 |
+
print(json.dumps({"q_idx": di, 'n_prompt_tok': completion.usage.prompt_tokens,
|
| 1694 |
+
'n_completion_tok': completion.usage.completion_tokens,
|
| 1695 |
+
'hypothesis': answer,
|
| 1696 |
+
'wall_time_sec': round(time.time() - item_start_time, 3)}), flush=True)
|
| 1697 |
+
print(json.dumps({"q_idx": di, 'question_id': qid,
|
| 1698 |
+
'hypothesis': answer,
|
| 1699 |
+
'n_prompt_tok': completion.usage.prompt_tokens,
|
| 1700 |
+
'n_completion_tok': completion.usage.completion_tokens,
|
| 1701 |
+
"attempt_record": attempt_record,
|
| 1702 |
+
"retrieved_sess_ids": final_sess.get_session_ids(),
|
| 1703 |
+
"retrieval_metric": retrieval_metric,
|
| 1704 |
+
"token_budget": token_budget,
|
| 1705 |
+
"wall_time_sec": time.time() - item_start_time}), file=out_f, flush=True)
|
| 1706 |
+
except Exception as e:
|
| 1707 |
+
print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
|
| 1708 |
+
continue
|
| 1709 |
+
|
| 1710 |
+
############# save cache ##########################
|
| 1711 |
+
|
| 1712 |
+
if not args.no_save_cache:
|
| 1713 |
+
with open(plan_cache_file, "w") as fw:
|
| 1714 |
+
json.dump(qid2plan, fw, indent=2)
|
| 1715 |
+
|
| 1716 |
+
with open(veri_reading_log_file, "w") as fw:
|
| 1717 |
+
json.dump(qid2rel_sess_ids, fw, indent=2)
|
memory/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from memory.episodic_store import EpisodicMemoryStore
|
| 2 |
+
from memory.semantic_store import SemanticMemoryStore
|
memory/episodic_store.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
episodic_store.py
|
| 3 |
+
|
| 4 |
+
Lazy loader for all_sessions.json. Keeps the full raw-turn data in memory
|
| 5 |
+
and returns sessions on demand by session ID, avoiding loading all sessions
|
| 6 |
+
into ChatHistory upfront.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from typing import Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class EpisodicMemoryStore:
|
| 15 |
+
def __init__(self, all_sessions_path: str):
|
| 16 |
+
print(f"[EpisodicMemoryStore] Loading {all_sessions_path} ...")
|
| 17 |
+
with open(all_sessions_path) as f:
|
| 18 |
+
self._data: Dict[str, List] = json.load(f)
|
| 19 |
+
print(f"[EpisodicMemoryStore] Loaded {len(self._data)} sessions.")
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def _parse_date(date_str: str) -> datetime:
|
| 23 |
+
"""Convert '2023/04/10 (Mon) 17:50' to datetime."""
|
| 24 |
+
date_part = date_str.split('(')[0].strip()
|
| 25 |
+
time_part = date_str.split(')')[-1].strip()
|
| 26 |
+
return datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M")
|
| 27 |
+
|
| 28 |
+
def get_raw_sessions(
|
| 29 |
+
self,
|
| 30 |
+
sess_ids: List[str],
|
| 31 |
+
date_lookup: Dict[str, str],
|
| 32 |
+
topic_lookup: Optional[Dict[str, List[str]]] = None,
|
| 33 |
+
) -> List[dict]:
|
| 34 |
+
"""
|
| 35 |
+
Return a list of session dicts compatible with ChatHistory(sessions=...).
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
sess_ids: Ordered list of session IDs to load.
|
| 39 |
+
date_lookup: Mapping sess_id -> date string (e.g. '2023/04/10 (Mon) 17:50').
|
| 40 |
+
topic_lookup: Optional mapping sess_id -> list of topic strings.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
List of session dicts, each with keys:
|
| 44 |
+
session_id, session_date, session, topic, timestamp
|
| 45 |
+
"""
|
| 46 |
+
sessions = []
|
| 47 |
+
for sid in sess_ids:
|
| 48 |
+
if sid not in self._data:
|
| 49 |
+
continue
|
| 50 |
+
date_str = date_lookup.get(sid, "")
|
| 51 |
+
try:
|
| 52 |
+
ts = self._parse_date(date_str) if date_str else datetime.min
|
| 53 |
+
except Exception:
|
| 54 |
+
ts = datetime.min
|
| 55 |
+
sessions.append({
|
| 56 |
+
"session_id": sid,
|
| 57 |
+
"session_date": date_str,
|
| 58 |
+
"session": self._data[sid],
|
| 59 |
+
"topic": topic_lookup.get(sid, []) if topic_lookup else [],
|
| 60 |
+
"timestamp": ts,
|
| 61 |
+
})
|
| 62 |
+
return sessions
|
memory/semantic_store.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
semantic_store.py
|
| 3 |
+
|
| 4 |
+
Wrapper around all_session_summary.json and all_session_user_facts.json.
|
| 5 |
+
Provides:
|
| 6 |
+
- keyword_search(): find sessions whose semantic text contains given keywords
|
| 7 |
+
- to_prompt(): format semantic context for LLM consumption
|
| 8 |
+
- get_text(): return raw semantic text for a session (for embedding/search)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
from typing import Dict, List, Optional
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SemanticMemoryStore:
|
| 16 |
+
def __init__(self, summary_path: str, facts_path: str):
|
| 17 |
+
print(f"[SemanticMemoryStore] Loading {summary_path} ...")
|
| 18 |
+
with open(summary_path) as f:
|
| 19 |
+
self._summaries: Dict[str, dict] = json.load(f)
|
| 20 |
+
|
| 21 |
+
print(f"[SemanticMemoryStore] Loading {facts_path} ...")
|
| 22 |
+
with open(facts_path) as f:
|
| 23 |
+
self._facts: Dict[str, list] = json.load(f)
|
| 24 |
+
|
| 25 |
+
print(f"[SemanticMemoryStore] Loaded {len(self._summaries)} summaries, "
|
| 26 |
+
f"{len(self._facts)} fact entries.")
|
| 27 |
+
|
| 28 |
+
def get_summary(self, sess_id: str) -> str:
|
| 29 |
+
"""Return the session-level summary string, or empty string."""
|
| 30 |
+
entry = self._summaries.get(sess_id, {})
|
| 31 |
+
return entry.get("session_summary", "").strip()
|
| 32 |
+
|
| 33 |
+
def get_facts_text(self, sess_id: str) -> str:
|
| 34 |
+
"""Return user facts as a single joined string, or empty string."""
|
| 35 |
+
fact_list = self._facts.get(sess_id, [])
|
| 36 |
+
if not fact_list:
|
| 37 |
+
return ""
|
| 38 |
+
return " ".join(
|
| 39 |
+
f["user-info"] for f in fact_list
|
| 40 |
+
if isinstance(f, dict) and f.get("user-info")
|
| 41 |
+
).strip()
|
| 42 |
+
|
| 43 |
+
def get_text(self, sess_id: str) -> str:
|
| 44 |
+
"""Return summary + facts combined (for keyword search or display)."""
|
| 45 |
+
parts = [self.get_summary(sess_id), self.get_facts_text(sess_id)]
|
| 46 |
+
return " ".join(p for p in parts if p)
|
| 47 |
+
|
| 48 |
+
def keyword_search(self, keywords: List[str], haystack_sess_ids: List[str]) -> List[str]:
|
| 49 |
+
"""
|
| 50 |
+
Search semantic text (summary + facts) of the given sessions for any keyword.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
List of matching session IDs (preserving haystack order).
|
| 54 |
+
"""
|
| 55 |
+
matched = []
|
| 56 |
+
kws_lower = [kw.lower() for kw in keywords if kw]
|
| 57 |
+
for sid in haystack_sess_ids:
|
| 58 |
+
text = self.get_text(sid).lower()
|
| 59 |
+
if any(kw in text for kw in kws_lower):
|
| 60 |
+
matched.append(sid)
|
| 61 |
+
return matched
|
| 62 |
+
|
| 63 |
+
def to_prompt(self, sess_ids: List[str], date_lookup: Optional[Dict[str, str]] = None) -> str:
|
| 64 |
+
"""
|
| 65 |
+
Format semantic context for these sessions as a prompt string.
|
| 66 |
+
|
| 67 |
+
Each session block:
|
| 68 |
+
Session Date: <date>
|
| 69 |
+
Summary: <session_summary>
|
| 70 |
+
User Facts: <fact1>; <fact2>; ...
|
| 71 |
+
"""
|
| 72 |
+
lines = []
|
| 73 |
+
for sid in sess_ids:
|
| 74 |
+
date_str = date_lookup.get(sid, "") if date_lookup else ""
|
| 75 |
+
summary = self.get_summary(sid)
|
| 76 |
+
facts_text = self.get_facts_text(sid)
|
| 77 |
+
|
| 78 |
+
block = f"Session ID: {sid}"
|
| 79 |
+
if date_str:
|
| 80 |
+
block += f"\nSession Date: {date_str}"
|
| 81 |
+
if summary:
|
| 82 |
+
block += f"\nSummary: {summary}"
|
| 83 |
+
if facts_text:
|
| 84 |
+
block += f"\nUser Facts: {facts_text}"
|
| 85 |
+
lines.append(block)
|
| 86 |
+
|
| 87 |
+
return "\n\n".join(lines)
|
model_zoo.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_zoo = {
|
| 2 |
+
# OpenAI / Azure-hosted models (deployment string, api_version)
|
| 3 |
+
'gpt-5': ("gpt-5_2025-08-07", "2024-12-01-preview"),
|
| 4 |
+
"gpt-4.1-azure": ("gpt-4.1_2025-04-14", "2025-04-01-preview"),
|
| 5 |
+
'gpt-4o': ('gpt-4o_2024-11-20', '2024-10-21'),
|
| 6 |
+
'gpt-4o-mini': ("gpt-4o-mini", ""),
|
| 7 |
+
'gpt-5-openai': ("gpt-5", ""),
|
| 8 |
+
'gpt-5-mini-openai': ("gpt-5-mini", ""),
|
| 9 |
+
|
| 10 |
+
# vLLM-hosted models (OpenAI-compatible server)
|
| 11 |
+
'Qwen3-30B-A3B-Instruct-2507': ("Qwen/Qwen3-30B-A3B-Instruct-2507", ""),
|
| 12 |
+
'Qwen3-VL-30B-A3B-Instruct': ("Qwen3-VL-30B-A3B-Instruct", ""),
|
| 13 |
+
|
| 14 |
+
# Anthropic models via direct Anthropic API (uses ANTHROPIC_API_KEY)
|
| 15 |
+
'claude-opus-4-6': ("claude-opus-4-6", ""),
|
| 16 |
+
'claude-sonnet-4-6': ("claude-sonnet-4-6", ""),
|
| 17 |
+
|
| 18 |
+
# Anthropic / DeepSeek via an OpenAI-compatible LiteLLM proxy
|
| 19 |
+
# (uses LITELLM_API_KEY; selected by main.py's --tritonai flag)
|
| 20 |
+
'claude-opus-4-6-tritonai': ("us.anthropic.claude-opus-4-6-v1", ""),
|
| 21 |
+
'claude-sonnet-4-6-tritonai': ("us.anthropic.claude-sonnet-4-6-v1", ""),
|
| 22 |
+
'deepseek-r1-tritonai': ("us.deepseek.r1-v1:0", ""),
|
| 23 |
+
|
| 24 |
+
# Models served via an OpenAI-compatible inference API (uses NV_API_KEY)
|
| 25 |
+
'gpt-5.1': ("openai/openai/gpt-5.1", ""),
|
| 26 |
+
'gpt-5.2': ("openai/openai/gpt-5.2", ""),
|
| 27 |
+
'gpt-5.5': ("openai/openai/gpt-5.5", ""),
|
| 28 |
+
'gpt-4.1': ("us/azure/openai/gpt-4.1", ""),
|
| 29 |
+
'Qwen3.5-397B-A17B': ("nvidia/qwen/qwen3-5-397b-a17b", ""),
|
| 30 |
+
'Kimi-K2.6': ("nvidia/moonshotai/kimi-k2.6", ""),
|
| 31 |
+
}
|
prompts/agentic_retrieval_prompt.txt
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an intelligent assistant for memory retrieval.
|
| 2 |
+
|
| 3 |
+
Your goal is to **design and refine retrieval strategies** for answering a user’s query by leveraging the user’s memory resources (profile, topic, chat history, and other contextual knowledge).
|
| 4 |
+
|
| 5 |
+
### Core Principles
|
| 6 |
+
|
| 7 |
+
1. Decision-first retrieval: Decide whether retrieval is necessary.
|
| 8 |
+
|
| 9 |
+
* If the query can be answered **directly from the user’s profile**, retrieval is unnecessary.
|
| 10 |
+
* Otherwise, retrieval strategies must be proposed.
|
| 11 |
+
* If there is some useful information in user's profile, keep the information and continue to retrieve.
|
| 12 |
+
|
| 13 |
+
2. Identify relevant topics:
|
| 14 |
+
|
| 15 |
+
* Given the topic list from user's chat history, identify topics related to the query.
|
| 16 |
+
* The topics will be used to narrow down search space. Be inclusive but not too general.
|
| 17 |
+
|
| 18 |
+
3. Multi-method retrieval: Multiple retrieval methods may be combined:
|
| 19 |
+
|
| 20 |
+
* **Keyword-based retrieval** for unique names, places, or identifiers.
|
| 21 |
+
* **Embedding-based semantic search** when the query is vague, abstract, or conversational.
|
| 22 |
+
* **Time-based filtering** ONLY when the query references dates, ranges, or relative temporal expressions (e.g., “last week,” “yesterday”), use the given `query_date` to resolve them into precise ISO 8601 ranges.
|
| 23 |
+
|
| 24 |
+
4. Loop-aware evidence collection: Retrieval may occur in **multiple iterations (loops)**. At each loop:
|
| 25 |
+
|
| 26 |
+
* Collect **evidence** from:
|
| 27 |
+
* User profile (static attributes like age, location, job, preferences)
|
| 28 |
+
* Topics (higher-level semantic indexing)
|
| 29 |
+
* Raw chat sessions (extract **key clue sentences**)
|
| 30 |
+
* Incorporate **previous retrieval attempts** (if provided) and refine strategy. If your previous attempt failed to retreive relevant sessions or evidneces, try without filter methods such as topic and time filters.
|
| 31 |
+
|
| 32 |
+
5. Consistent JSON output: All outputs must follow the unified schema to enable downstream automation.
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## JSON Output Schema
|
| 37 |
+
|
| 38 |
+
```json
|
| 39 |
+
{
|
| 40 |
+
"query": "<original user query>",
|
| 41 |
+
"query_date": "<ISO 8601 date string>",
|
| 42 |
+
"loop_iteration": <integer, starts at 1>,
|
| 43 |
+
"retrieval_decision": "<none | retrieval_required>",
|
| 44 |
+
"answer": "<none | answer if possilbe>",
|
| 45 |
+
"topics": ["<list of relevant topics>"],
|
| 46 |
+
"strategy": [
|
| 47 |
+
{
|
| 48 |
+
"method": "<keyword | time_filter | embedding>",
|
| 49 |
+
"conditions": "<why this method is chosen>",
|
| 50 |
+
"keywords": ["<list of keywords>"],
|
| 51 |
+
"time_range": ["<start_date>", "<end_date>"]
|
| 52 |
+
}
|
| 53 |
+
],
|
| 54 |
+
"evidence": {
|
| 55 |
+
"profile": ["<relevant profile snippets>"],
|
| 56 |
+
"chat_clues": ["<list of key sentences extracted from chat history>"]
|
| 57 |
+
},
|
| 58 |
+
"previous_attempts": [
|
| 59 |
+
{
|
| 60 |
+
"loop_iteration": <integer>,
|
| 61 |
+
"strategy": [ ... ],
|
| 62 |
+
"evidence": { ... },
|
| 63 |
+
"outcome": "<insufficient | useful | final_answer_ready>"
|
| 64 |
+
}
|
| 65 |
+
]
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
---
|
| 70 |
+
|
| 71 |
+
## Examples
|
| 72 |
+
|
| 73 |
+
### Example 1: No retrieval needed (profile sufficient)
|
| 74 |
+
|
| 75 |
+
```json
|
| 76 |
+
{
|
| 77 |
+
"query": "Where do I live?",
|
| 78 |
+
"query_date": "2025-08-29",
|
| 79 |
+
"loop_iteration": 1,
|
| 80 |
+
"retrieval_decision": "none",
|
| 81 |
+
"answer": "San Diego, California",
|
| 82 |
+
"topics": [],
|
| 83 |
+
"strategy": [],
|
| 84 |
+
"evidence": {
|
| 85 |
+
"profile": ["User lives in San Diego, California."],
|
| 86 |
+
"chat_clues": []
|
| 87 |
+
},
|
| 88 |
+
"previous_attempts": []
|
| 89 |
+
}
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
### Example 2: Keyword strategy
|
| 95 |
+
|
| 96 |
+
```json
|
| 97 |
+
{
|
| 98 |
+
"query": "Who did I go to the Grand Canyon with?",
|
| 99 |
+
"query_date": "2025-08-29",
|
| 100 |
+
"loop_iteration": 1,
|
| 101 |
+
"retrieval_decision": "retrieval_required",
|
| 102 |
+
"answer": "none",
|
| 103 |
+
"topics": ["Travel & Transportation", "Family & Relationships", "Personal Development"],
|
| 104 |
+
"strategy": [
|
| 105 |
+
{
|
| 106 |
+
"method": "keyword",
|
| 107 |
+
"conditions": "Query contains specific event and place name.",
|
| 108 |
+
"keywords": ["Grand Canyon"]
|
| 109 |
+
}
|
| 110 |
+
],
|
| 111 |
+
"evidence": {
|
| 112 |
+
"profile": [],
|
| 113 |
+
"chat_clues": []
|
| 114 |
+
},
|
| 115 |
+
"previous_attempts": []
|
| 116 |
+
}
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
---
|
| 120 |
+
|
| 121 |
+
### Example 3: Time-filter strategy with relative date
|
| 122 |
+
|
| 123 |
+
Query: How many miles have I hiked in the past two weeks?
|
| 124 |
+
Query date: 2023/05/22 (Mon) 23:31
|
| 125 |
+
|
| 126 |
+
```json
|
| 127 |
+
{
|
| 128 |
+
"query": "How many miles have I hiked in the past two weeks?",
|
| 129 |
+
"query_date": "2023-05-22 (Mon) 23:31",
|
| 130 |
+
"loop_iteration": 1,
|
| 131 |
+
"retrieval_decision": "retrieval_required",
|
| 132 |
+
"answer": "none",
|
| 133 |
+
"topics": ["Sports & Fitness", "Health & Wellness", "Travel & Transportation", "Personal Development"],
|
| 134 |
+
"strategy": [
|
| 135 |
+
{
|
| 136 |
+
"method": "time_filter",
|
| 137 |
+
"conditions": "Temporal phrase 'past two weeks' resolved using query_date.",
|
| 138 |
+
"time_range": ["2023-05-08", "2023-05-22"]
|
| 139 |
+
}
|
| 140 |
+
],
|
| 141 |
+
"evidence": {
|
| 142 |
+
"profile": [],
|
| 143 |
+
"chat_clues": []
|
| 144 |
+
},
|
| 145 |
+
"previous_attempts": [],
|
| 146 |
+
}
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
---
|
| 150 |
+
|
| 151 |
+
### Example 4: Keyword + Embedding
|
| 152 |
+
|
| 153 |
+
```json
|
| 154 |
+
{
|
| 155 |
+
"query": "What did my doctor say about my back pain treatment options?",
|
| 156 |
+
"query_date": "2025-08-29",
|
| 157 |
+
"loop_iteration": 1,
|
| 158 |
+
"retrieval_decision": "retrieval_required",
|
| 159 |
+
"answer": "none",
|
| 160 |
+
"topics": ["Health & Wellness", "Work & Career"],
|
| 161 |
+
"strategy": [
|
| 162 |
+
{
|
| 163 |
+
"method": "keyword",
|
| 164 |
+
"conditions": "Query contains specific medical terms that should be matched exactly.",
|
| 165 |
+
"keywords": ["doctor", "back pain", "treatment"]
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"method": "embedding",
|
| 169 |
+
"conditions": "Query involves medical advice and recommendations which may be expressed in various conversational ways in chat history."
|
| 170 |
+
}
|
| 171 |
+
],
|
| 172 |
+
"evidence": {
|
| 173 |
+
"profile": [],
|
| 174 |
+
"chat_clues": []
|
| 175 |
+
},
|
| 176 |
+
"previous_attempts": []
|
| 177 |
+
}
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
### Example 5: Multi-method strategy with loop refinement
|
| 181 |
+
|
| 182 |
+
Loop 1 outcome was insufficient, Loop 2 continues with broader retrieval.
|
| 183 |
+
|
| 184 |
+
```json
|
| 185 |
+
{
|
| 186 |
+
"query": "What were we discussing about public issues?",
|
| 187 |
+
"query_date": "2025-08-29",
|
| 188 |
+
"loop_iteration": 2,
|
| 189 |
+
"retrieval_decision": "retrieval_required",
|
| 190 |
+
"answer": "none",
|
| 191 |
+
"topics": ["Government & Politics", "Environment & Sustainability", "Legal"],
|
| 192 |
+
"strategy": [
|
| 193 |
+
{
|
| 194 |
+
"method": "embedding",
|
| 195 |
+
"conditions": "To capture abstract conversational references."
|
| 196 |
+
}
|
| 197 |
+
],
|
| 198 |
+
"evidence": {
|
| 199 |
+
"profile": [],
|
| 200 |
+
"chat_clues": [
|
| 201 |
+
"User debated about climate policy impacts.",
|
| 202 |
+
"Conversation on local housing regulations was tagged as 'public issues'."
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
"previous_attempts": [
|
| 206 |
+
{
|
| 207 |
+
"loop_iteration": 1,
|
| 208 |
+
"strategy": [
|
| 209 |
+
{"method": "keyword", "conditions": "Query asks about 'public issue'", "keywords": ["public issue", "public issues", "issue"]}
|
| 210 |
+
],
|
| 211 |
+
"evidence": {
|
| 212 |
+
"profile": [],
|
| 213 |
+
"chat_clues": []
|
| 214 |
+
},
|
| 215 |
+
"outcome": "insufficient"
|
| 216 |
+
}
|
| 217 |
+
]
|
| 218 |
+
}
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
Use clear reasoning for how each method applies (or not) to the specific query. Be concise and precise.
|
| 222 |
+
|
| 223 |
+
**Do not include strategies in the strategy list unless they are needed for this query. Omit unused methods.**
|
| 224 |
+
**Always use the provided `query_date` to resolve any relative dates in the query.**
|
| 225 |
+
**Strictly follow the given JSON output format. Return only **one** JSON output.**
|
| 226 |
+
**Use time-based filtering ONLY when the query references dates, ranges, or relative temporal expressions (e.g., "last week," "yesterday," "in last 3 month")
|
prompts/agentic_retrieval_prompt_wo_profile.txt
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an intelligent assistant for memory retrieval.
|
| 2 |
+
|
| 3 |
+
Your goal is to **design and refine retrieval strategies** for answering a user’s query by leveraging the user’s memory resources (topic, chat history, and other contextual knowledge).
|
| 4 |
+
|
| 5 |
+
### Core Principles
|
| 6 |
+
|
| 7 |
+
1. Decision-first retrieval: Decide whether retrieval is necessary.
|
| 8 |
+
|
| 9 |
+
* If the query can be answered **directly**, retrieval is unnecessary.
|
| 10 |
+
* Otherwise, retrieval strategies must be proposed.
|
| 11 |
+
|
| 12 |
+
2. Identify relevant topics:
|
| 13 |
+
|
| 14 |
+
* Given the topic list from user's chat history, identify topics related to the query.
|
| 15 |
+
* The topics will be used to narrow down search space. Be inclusive but not too general.
|
| 16 |
+
|
| 17 |
+
3. Multi-method retrieval: Multiple retrieval methods may be combined:
|
| 18 |
+
|
| 19 |
+
* **Keyword-based retrieval** for unique names, places, or identifiers.
|
| 20 |
+
* **Embedding-based semantic search** when the query is vague, abstract, or conversational.
|
| 21 |
+
* **Time-based filtering** ONLY when the query references dates, ranges, or relative temporal expressions (e.g., “last week,” “yesterday”), use the given `query_date` to resolve them into precise ISO 8601 ranges.
|
| 22 |
+
|
| 23 |
+
4. Loop-aware evidence collection: Retrieval may occur in **multiple iterations (loops)**. At each loop:
|
| 24 |
+
|
| 25 |
+
* Collect **evidence** from:
|
| 26 |
+
* Topics (higher-level semantic indexing)
|
| 27 |
+
* Raw chat sessions (extract **key clue sentences**)
|
| 28 |
+
* Incorporate **previous retrieval attempts** (if provided) and refine strategy. If your previous attempt failed to retreive relevant sessions or evidneces, try without filter methods such as topic and time filters.
|
| 29 |
+
|
| 30 |
+
5. Consistent JSON output: All outputs must follow the unified schema to enable downstream automation.
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
## JSON Output Schema
|
| 35 |
+
|
| 36 |
+
```json
|
| 37 |
+
{
|
| 38 |
+
"query": "<original user query>",
|
| 39 |
+
"query_date": "<ISO 8601 date string>",
|
| 40 |
+
"loop_iteration": <integer, starts at 1>,
|
| 41 |
+
"retrieval_decision": "<none | retrieval_required>",
|
| 42 |
+
"answer": "<none | answer if possilbe>",
|
| 43 |
+
"topics": ["<list of relevant topics>"],
|
| 44 |
+
"strategy": [
|
| 45 |
+
{
|
| 46 |
+
"method": "<keyword | time_filter | embedding>",
|
| 47 |
+
"conditions": "<why this method is chosen>",
|
| 48 |
+
"keywords": ["<list of keywords>"],
|
| 49 |
+
"time_range": ["<start_date>", "<end_date>"]
|
| 50 |
+
}
|
| 51 |
+
],
|
| 52 |
+
"evidence": {
|
| 53 |
+
"profile": [],
|
| 54 |
+
"chat_clues": ["<list of key sentences extracted from chat history>"]
|
| 55 |
+
},
|
| 56 |
+
"previous_attempts": [
|
| 57 |
+
{
|
| 58 |
+
"loop_iteration": <integer>,
|
| 59 |
+
"strategy": [ ... ],
|
| 60 |
+
"evidence": { ... },
|
| 61 |
+
"outcome": "<insufficient | useful | final_answer_ready>"
|
| 62 |
+
}
|
| 63 |
+
]
|
| 64 |
+
}
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## Examples
|
| 70 |
+
|
| 71 |
+
### Example 1: Keyword strategy
|
| 72 |
+
|
| 73 |
+
```json
|
| 74 |
+
{
|
| 75 |
+
"query": "Who did I go to the Grand Canyon with?",
|
| 76 |
+
"query_date": "2025-08-29",
|
| 77 |
+
"loop_iteration": 1,
|
| 78 |
+
"retrieval_decision": "retrieval_required",
|
| 79 |
+
"answer": "none",
|
| 80 |
+
"topics": ["Travel & Transportation", "Family & Relationships", "Personal Development"],
|
| 81 |
+
"strategy": [
|
| 82 |
+
{
|
| 83 |
+
"method": "keyword",
|
| 84 |
+
"conditions": "Query contains specific event and place name.",
|
| 85 |
+
"keywords": ["Grand Canyon"]
|
| 86 |
+
}
|
| 87 |
+
],
|
| 88 |
+
"evidence": {
|
| 89 |
+
"profile": [],
|
| 90 |
+
"chat_clues": []
|
| 91 |
+
},
|
| 92 |
+
"previous_attempts": []
|
| 93 |
+
}
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
### Example 2: Time-filter strategy with relative date
|
| 99 |
+
|
| 100 |
+
Query: How many miles have I hiked in the past two weeks?
|
| 101 |
+
Query date: 2023/05/22 (Mon) 23:31
|
| 102 |
+
|
| 103 |
+
```json
|
| 104 |
+
{
|
| 105 |
+
"query": "How many miles have I hiked in the past two weeks?",
|
| 106 |
+
"query_date": "2023-05-22 (Mon) 23:31",
|
| 107 |
+
"loop_iteration": 1,
|
| 108 |
+
"retrieval_decision": "retrieval_required",
|
| 109 |
+
"answer": "none",
|
| 110 |
+
"topics": ["Sports & Fitness", "Health & Wellness", "Travel & Transportation", "Personal Development"],
|
| 111 |
+
"strategy": [
|
| 112 |
+
{
|
| 113 |
+
"method": "time_filter",
|
| 114 |
+
"conditions": "Temporal phrase 'past two weeks' resolved using query_date.",
|
| 115 |
+
"time_range": ["2023-05-08", "2023-05-22"]
|
| 116 |
+
}
|
| 117 |
+
],
|
| 118 |
+
"evidence": {
|
| 119 |
+
"profile": [],
|
| 120 |
+
"chat_clues": []
|
| 121 |
+
},
|
| 122 |
+
"previous_attempts": [],
|
| 123 |
+
}
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
---
|
| 127 |
+
|
| 128 |
+
### Example 3: Keyword + Embedding
|
| 129 |
+
|
| 130 |
+
```json
|
| 131 |
+
{
|
| 132 |
+
"query": "What did my doctor say about my back pain treatment options?",
|
| 133 |
+
"query_date": "2025-08-29",
|
| 134 |
+
"loop_iteration": 1,
|
| 135 |
+
"retrieval_decision": "retrieval_required",
|
| 136 |
+
"answer": "none",
|
| 137 |
+
"topics": ["Health & Wellness", "Work & Career"],
|
| 138 |
+
"strategy": [
|
| 139 |
+
{
|
| 140 |
+
"method": "keyword",
|
| 141 |
+
"conditions": "Query contains specific medical terms that should be matched exactly.",
|
| 142 |
+
"keywords": ["doctor", "back pain", "treatment"]
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"method": "embedding",
|
| 146 |
+
"conditions": "Query involves medical advice and recommendations which may be expressed in various conversational ways in chat history."
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
"evidence": {
|
| 150 |
+
"profile": [],
|
| 151 |
+
"chat_clues": []
|
| 152 |
+
},
|
| 153 |
+
"previous_attempts": []
|
| 154 |
+
}
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
### Example 4: Multi-method strategy with loop refinement
|
| 158 |
+
|
| 159 |
+
Loop 1 outcome was insufficient, Loop 2 continues with broader retrieval.
|
| 160 |
+
|
| 161 |
+
```json
|
| 162 |
+
{
|
| 163 |
+
"query": "What were we discussing about public issues?",
|
| 164 |
+
"query_date": "2025-08-29",
|
| 165 |
+
"loop_iteration": 2,
|
| 166 |
+
"retrieval_decision": "retrieval_required",
|
| 167 |
+
"answer": "none",
|
| 168 |
+
"topics": ["Government & Politics", "Environment & Sustainability", "Legal"],
|
| 169 |
+
"strategy": [
|
| 170 |
+
{
|
| 171 |
+
"method": "embedding",
|
| 172 |
+
"conditions": "To capture abstract conversational references."
|
| 173 |
+
}
|
| 174 |
+
],
|
| 175 |
+
"evidence": {
|
| 176 |
+
"profile": [],
|
| 177 |
+
"chat_clues": [
|
| 178 |
+
"User debated about climate policy impacts.",
|
| 179 |
+
"Conversation on local housing regulations was tagged as 'public issues'."
|
| 180 |
+
]
|
| 181 |
+
},
|
| 182 |
+
"previous_attempts": [
|
| 183 |
+
{
|
| 184 |
+
"loop_iteration": 1,
|
| 185 |
+
"strategy": [
|
| 186 |
+
{"method": "keyword", "conditions": "Query asks about 'public issue'", "keywords": ["public issue", "public issues", "issue"]}
|
| 187 |
+
],
|
| 188 |
+
"evidence": {
|
| 189 |
+
"profile": [],
|
| 190 |
+
"chat_clues": []
|
| 191 |
+
},
|
| 192 |
+
"outcome": "insufficient"
|
| 193 |
+
}
|
| 194 |
+
]
|
| 195 |
+
}
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
Use clear reasoning for how each method applies (or not) to the specific query. Be concise and precise.
|
| 199 |
+
|
| 200 |
+
**Do not include strategies in the strategy list unless they are needed for this query. Omit unused methods.**
|
| 201 |
+
**Always use the provided `query_date` to resolve any relative dates in the query.**
|
| 202 |
+
**Strictly follow the given JSON output format. Return only **one** JSON output.**
|
| 203 |
+
**Use time-based filtering ONLY when the query references dates, ranges, or relative temporal expressions (e.g., "last week," "yesterday," "in last 3 month")
|
prompts/keyword_search_prompt.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are a memory retrieval assistant specialized in extracting highly specific search keywords from user queries.
|
| 2 |
+
|
| 3 |
+
Your task: Analyze the user's query and identify the most distinctive, specific keywords that will precisely match relevant memories.
|
| 4 |
+
|
| 5 |
+
Guidelines:
|
| 6 |
+
- Extract ONLY specific, unique terms (proper nouns, specific places, distinct events, particular objects)
|
| 7 |
+
- Prioritize specificity over completeness - fewer precise keywords are better than many generic ones
|
| 8 |
+
- EXCLUDE generic words like: trip, visit, appointment, gift, location, person, time, day
|
| 9 |
+
- EXCLUDE question words: who, what, when, where, why, how
|
| 10 |
+
- EXCLUDE auxiliary verbs: did, was, have, can, should
|
| 11 |
+
- Include 1-3 keywords maximum, ordered by specificity
|
| 12 |
+
- Preserve exact phrasing for proper nouns and named entities
|
| 13 |
+
|
| 14 |
+
Query: "Who did I go to the Grand Canyon with?"
|
| 15 |
+
|
| 16 |
+
Output format (JSON):
|
| 17 |
+
```json
|
| 18 |
+
{
|
| 19 |
+
"query": "Who did I go to the Grand Canyon with?",
|
| 20 |
+
"keywords": ["Grand Canyon"]
|
| 21 |
+
}
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Additional examples:
|
| 25 |
+
- "What did Sarah give me for my birthday?" → ["Sarah", "birthday"]
|
| 26 |
+
- "When did I last visit Dr. Martinez?" → ["Dr. Martinez"]
|
| 27 |
+
- "Where did I put my car keys?" → ["car keys"]
|
| 28 |
+
- "What happened at the team meeting?" → ["team meeting"]
|
| 29 |
+
- "Did I finish the Johnson report?" → ["Johnson report"]
|
| 30 |
+
|
| 31 |
+
Query:
|
prompts/read_and_extract_prompt.txt
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are given a list of chat sessions between a user and an AI assistant.
|
| 2 |
+
|
| 3 |
+
Your task:
|
| 4 |
+
Given a question, identify the sessions that are relevant to answer the question.
|
| 5 |
+
|
| 6 |
+
**Output format:**
|
| 7 |
+
|
| 8 |
+
```json
|
| 9 |
+
{"index": [<list of 0-based session indices>], "evidence": [<list of sentences that serve as evidence to answer the question>]}
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
If none are relevant, return:
|
| 13 |
+
|
| 14 |
+
```json
|
| 15 |
+
{"index": [], "evidence": []}
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
# Examples
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## Example 1 (multiple relevant sessions)
|
| 25 |
+
|
| 26 |
+
Question: Who did I go hiking with at Mount Rainier?
|
| 27 |
+
Sessions:
|
| 28 |
+
|
| 29 |
+
### Session Index: 0
|
| 30 |
+
|
| 31 |
+
[{"role": "user", "content": "I went hiking at Mount Rainier on May 10."}, {"role": "assistant", "content": "Nice! Which trail did you take?"}, {"role": "user", "content": "Skyline Trail."}]
|
| 32 |
+
|
| 33 |
+
### Session Index: 1
|
| 34 |
+
|
| 35 |
+
[{"role": "user", "content": "The weather was great in Yosemite."}, {"role": "assistant", "content": "Did you go there recently?"}]
|
| 36 |
+
|
| 37 |
+
### Session Index: 2
|
| 38 |
+
|
| 39 |
+
[{"role": "user", "content": "On May 10, I hiked with Sarah and John."}, {"role": "assistant", "content": "Sounds like a fun group!"}]
|
| 40 |
+
|
| 41 |
+
### Session Index: 5
|
| 42 |
+
|
| 43 |
+
[{"role": "user", "content": "I heard Sarah recently broke up with her boyfriend."}, {"role": "assistant", "content": "Sarah is your close friend from Seattle, right?"}, {"role": "user", "content": "Yes, we’ve been friends since college."}]
|
| 44 |
+
|
| 45 |
+
Explanation:
|
| 46 |
+
|
| 47 |
+
* Session 0 is relevant because it provides the **location and date** (“Mount Rainier on **May 10**”) but no names.
|
| 48 |
+
* Session 1 Yosemite is irrelevant to Rainier.
|
| 49 |
+
* Session 2 is relevant because it provides the **names and date** (“**On May 10**, I hiked with **Sarah and John**”) but no location.
|
| 50 |
+
* Session 5 adds background about Sarah and is not needed to answer the question.
|
| 51 |
+
|
| 52 |
+
The answer requires **combining the shared date (May 10)** across Sessions 0 and 2 to link the names to Mount Rainier.
|
| 53 |
+
|
| 54 |
+
**Final JSON output:**
|
| 55 |
+
|
| 56 |
+
```json
|
| 57 |
+
{"index": [0, 2], "evidence": ["User went hiking at Mount Rainier on May 10.", "User hiked with Sarah and John on May 10."]}
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## Example 2 (no relevant sessions)
|
| 63 |
+
|
| 64 |
+
Question: "When did I buy my new iPhone?"
|
| 65 |
+
Sessions:
|
| 66 |
+
|
| 67 |
+
### Session Index: 0
|
| 68 |
+
|
| 69 |
+
[{"role": "user", "content": "I love my iPhone 14 camera!"}, {"role": "assistant", "content": "Yes, it takes great photos."}]
|
| 70 |
+
|
| 71 |
+
### Session Index: 3
|
| 72 |
+
|
| 73 |
+
[{"role": "user", "content": "I’m thinking about buying a new iPhone soon."}, {"role": "assistant", "content": "The new model will be released in the fall."}]
|
| 74 |
+
|
| 75 |
+
Explanation:
|
| 76 |
+
|
| 77 |
+
* None of the sessions give an exact purchase date.
|
| 78 |
+
* Session 0 confirms ownership but not when it was purchased.
|
| 79 |
+
* Session 3 is about a future plan, not an actual purchase.
|
| 80 |
+
|
| 81 |
+
**Final JSON output:**
|
| 82 |
+
|
| 83 |
+
```json
|
| 84 |
+
{"index": [], "evidence": []}
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## Example 3 (single relevant session)
|
| 90 |
+
|
| 91 |
+
Question: "What is the name of my cat?"
|
| 92 |
+
Sessions:
|
| 93 |
+
|
| 94 |
+
### Session Index: 0
|
| 95 |
+
|
| 96 |
+
[{"role": "user", "content": "I just adopted a cat named Luna."}, {"role": "assistant", "content": "She must be adorable!"}, {"role": "user", "content": "Yes, she’s very playful."}]
|
| 97 |
+
|
| 98 |
+
### Session Index: 4
|
| 99 |
+
|
| 100 |
+
[{"role": "assistant", "content": "Dogs are usually easier to train than cats."}, {"role": "user", "content": "Yeah, but I love cats."}]
|
| 101 |
+
|
| 102 |
+
Explanation:
|
| 103 |
+
|
| 104 |
+
* Only session 0 contains the name of the cat, “Luna”.
|
| 105 |
+
* Session 4 is general pet advice and irrelevant.
|
| 106 |
+
|
| 107 |
+
**Final JSON output:**
|
| 108 |
+
|
| 109 |
+
```json
|
| 110 |
+
{"index": [0], "evidence": ["User adopted a cat named Luna."]}
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
## Example 4 (combining across sessions)
|
| 116 |
+
|
| 117 |
+
Question: "What was the distance of my last two hikes?"
|
| 118 |
+
Sessions:
|
| 119 |
+
|
| 120 |
+
### Session Index: 0
|
| 121 |
+
|
| 122 |
+
[{"role": "user", "content": "Last weekend, I hiked 5 miles at Storm King Trail."}, {"role": "assistant", "content": "That’s a nice trail."}, {"role": "user", "content": "Yes, the views were amazing."}]
|
| 123 |
+
|
| 124 |
+
### Session Index: 1
|
| 125 |
+
|
| 126 |
+
[{"role": "user", "content": "Two weeks ago, I did a 7-mile hike at Rattlesnake Ridge."}, {"role": "assistant", "content": "That’s a great workout!"}, {"role": "user", "content": "It was challenging but worth it."}]
|
| 127 |
+
|
| 128 |
+
### Session Index: 3
|
| 129 |
+
|
| 130 |
+
[{"role": "assistant", "content": "Next time, try Mount Si!"}, {"role": "user", "content": "I’ll add it to my list."}]
|
| 131 |
+
|
| 132 |
+
Explanation:
|
| 133 |
+
|
| 134 |
+
* Session 0 provides the first hike’s distance.
|
| 135 |
+
* Session 1 provides the second hike’s distance.
|
| 136 |
+
* In Session 3, the assistant suggests another hike but contains no actual distance information and there is no guarantee that the user actually went hiking or not.
|
| 137 |
+
|
| 138 |
+
**Final JSON output:**
|
| 139 |
+
|
| 140 |
+
```json
|
| 141 |
+
{"index": [0, 1], "evidence": ["User hiked 5 miles at Storm King Trail last weekend.", "User hiked 7 miles at Rattlesnake Ridge two weeks ago."]}
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
## Example 5 (time reference)
|
| 147 |
+
|
| 148 |
+
Question: "What did I do last Friday?"
|
| 149 |
+
Question date: 2023/05/23 (Tue) 12:26
|
| 150 |
+
Sessions:
|
| 151 |
+
|
| 152 |
+
### Session Index: 0
|
| 153 |
+
### Session Date: 2023/05/20 (Sat) 03:14
|
| 154 |
+
|
| 155 |
+
[{"role": "user", "content": "Yesterday, I went to a concert downtown."}, {"role": "assistant", "content": "Who performed?"}, {"role": "user", "content": "It was The Lumineers."}]
|
| 156 |
+
|
| 157 |
+
### Session Index: 2
|
| 158 |
+
### Session Date: 2023/05/22 (Mon) 11:59
|
| 159 |
+
|
| 160 |
+
[{"role": "user", "content": "I went hiking last weekend."}, {"role": "assistant", "content": "Where did you go?"}, {"role": "user", "content": "Mount Si."}]
|
| 161 |
+
|
| 162 |
+
Explanation:
|
| 163 |
+
|
| 164 |
+
* Only session 0 is relevant because: the session took place on **Saturday, May 20**, and the user says “Yesterday,” which, relative to the session date (Saturday), refers to **Friday, May 19**. Given the question date (**Tuesday, May 23**), “last Friday” would be **May 19**, so the concert occurred on last Friday, meaning this is the correct match for what the user did around that timeframe. The activity (“went to a concert downtown” with The Lumineers) answers the question.
|
| 165 |
+
* Session 2 is irrelevant because “last weekend” (relative to Monday, May 22) refers to May 20–21, not **Friday, May 19**, so it does not answer the question.
|
| 166 |
+
|
| 167 |
+
**Final JSON output:**
|
| 168 |
+
|
| 169 |
+
```json
|
| 170 |
+
{"index": [0], "evidence": ["User went to a concert downtown on last Friday, May 19 and user said it was the Lumineers."]}
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
---
|
| 174 |
+
|
| 175 |
+
Use exact indices from the provided list of sessions in your JSON output.
|
| 176 |
+
If the question is related to time, specify the date in the evidence.
|