anonymous-penguin committed on
Commit 9c60174 · verified · 1 Parent(s): 2455d57

Initial code release


Code for memory-retrieval experiments.

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README.md +163 -0
  2. baselines/MemoChat/LICENSE +21 -0
  3. baselines/MemoChat/README.md +47 -0
  4. baselines/MemoChat/code/codes/api/gpt_2k.py +90 -0
  5. baselines/MemoChat/code/codes/api/gpt_memochat.py +189 -0
  6. baselines/MemoChat/code/codes/api/llm_judge.py +96 -0
  7. baselines/MemoChat/code/codes/eval/eval_instruction_tuning_tasks.py +169 -0
  8. baselines/MemoChat/code/codes/eval/get_model_infer_memochat.py +291 -0
  9. baselines/MemoChat/code/codes/eval/get_model_infer_simple.py +150 -0
  10. baselines/MemoChat/code/codes/train/data_preprocess.py +118 -0
  11. baselines/MemoChat/code/codes/train/train.py +150 -0
  12. baselines/MemoChat/code/configs/ds_config_13b.json +53 -0
  13. baselines/MemoChat/code/configs/ds_config_33b.json +57 -0
  14. baselines/MemoChat/code/configs/ds_config_3b.json +39 -0
  15. baselines/MemoChat/code/configs/ds_config_7b.json +49 -0
  16. baselines/MemoChat/code/scripts/llm_judge.sh +35 -0
  17. baselines/MemoChat/code/scripts/memochat.sh +34 -0
  18. baselines/MemoChat/code/scripts/memochat_gpt.sh +18 -0
  19. baselines/MemoChat/code/scripts/tuning.sh +110 -0
  20. baselines/MemoChat/core_requirement.txt +13 -0
  21. baselines/MemoChat/run_memochat_baseline.py +634 -0
  22. baselines/raptor/LICENSE.txt +21 -0
  23. baselines/raptor/README.md +204 -0
  24. baselines/raptor/raptor/EmbeddingModels.py +37 -0
  25. baselines/raptor/raptor/FaissRetriever.py +201 -0
  26. baselines/raptor/raptor/QAModels.py +185 -0
  27. baselines/raptor/raptor/RetrievalAugmentation.py +306 -0
  28. baselines/raptor/raptor/Retrievers.py +8 -0
  29. baselines/raptor/raptor/SummarizationModels.py +74 -0
  30. baselines/raptor/raptor/__init__.py +16 -0
  31. baselines/raptor/raptor/cluster_tree_builder.py +151 -0
  32. baselines/raptor/raptor/cluster_utils.py +185 -0
  33. baselines/raptor/raptor/tree_builder.py +369 -0
  34. baselines/raptor/raptor/tree_retriever.py +327 -0
  35. baselines/raptor/raptor/tree_structures.py +28 -0
  36. baselines/raptor/raptor/utils.py +208 -0
  37. baselines/raptor/requirements.txt +11 -0
  38. baselines/raptor/run_raptor_baseline.py +511 -0
  39. baselines/read-agent/read_agent_demo.ipynb +976 -0
  40. baselines/read-agent/run_readagent_baseline.py +424 -0
  41. evaluate_qa.py +916 -0
  42. main.py +1717 -0
  43. memory/__init__.py +2 -0
  44. memory/episodic_store.py +62 -0
  45. memory/semantic_store.py +87 -0
  46. model_zoo.py +31 -0
  47. prompts/agentic_retrieval_prompt.txt +226 -0
  48. prompts/agentic_retrieval_prompt_wo_profile.txt +203 -0
  49. prompts/keyword_search_prompt.txt +31 -0
  50. prompts/read_and_extract_prompt.txt +176 -0
README.md ADDED
@@ -0,0 +1,163 @@
# Long-Term Memory Retrieval Benchmark

Code release for the experiments described in the accompanying paper:
- **Hierarchical memory** organization (User Profile / Semantic / Episodic).
- **Plan-Act-Read agentic retrieval** that interleaves keyword, time-filter, and embedding search.
- **Flat / dense / oracle baselines** for comparison.

## Repository layout

```
.
├── main.py                      # End-to-end QA pipeline (agent, embed, keyword modes)
├── evaluate_qa.py               # Atomic-rubric QA evaluator (strict + partial)
├── model_zoo.py                 # Model registry
├── prompts/                     # Prompt templates
│   ├── agentic_retrieval_prompt.txt
│   ├── agentic_retrieval_prompt_wo_profile.txt
│   ├── keyword_search_prompt.txt
│   └── read_and_extract_prompt.txt
├── memory/                      # Episodic + semantic memory stores
├── baselines/
│   ├── MemoChat/                # MemoChat baseline (upstream code + our wrapper)
│   ├── raptor/                  # RAPTOR baseline (upstream code + our wrapper)
│   └── read-agent/              # ReadAgent baseline wrapper
├── scripts/
│   ├── build_retrieval_cache.py          # Pre-compute GTE-7B embeddings for the corpus
│   ├── make_v5_shards.py                 # Deterministic shard split by question_id
│   ├── merge_jsonl_by_dataset_order.py
│   ├── run_oracle_qa.py                  # Gold-session-only upper bound
│   ├── plot_main_results.py
│   ├── llm_judge_agreement.py
│   └── slurm/
│       ├── example_dense_retrieval.slurm
│       └── example_agentic_retrieval.slurm
└── requirements.txt
```

The benchmark dataset (`evolv_mem_v5.json`) is released separately; place it under `dataset/` along with the supporting files referenced by `main.py` (`all_sessions.json`, `all_session_summary.json`, etc.).

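For reference, the expected layout is roughly the following (a sketch inferred from the file names referenced above; your release may contain additional support files):

```
dataset/
├── evolv_mem_v5.json          # benchmark questions and references
├── all_sessions.json          # full conversation sessions (the memory corpus)
└── all_session_summary.json   # per-session summaries
```
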
## Setup

```bash
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```

### API keys

The pipeline calls LLMs through four optional providers; set whichever you plan to use:

| Provider                         | Env var             | Flag         |
|----------------------------------|---------------------|--------------|
| OpenAI-compatible inference API  | `NV_API_KEY`        | `--nvidia`   |
| OpenAI-compatible LiteLLM proxy  | `LITELLM_API_KEY`   | `--tritonai` |
| Direct Anthropic API             | `ANTHROPIC_API_KEY` | (default)    |
| Azure OpenAI                     | `AZURE_OPENAI_KEY`  | (default)    |

Each `--<flag>` selects which client the pipeline uses; entries in `model_zoo.py` are tagged accordingly.

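For example, to run retrieval through the OpenAI-compatible inference API while keeping the default clients available, export the corresponding keys before launching (placeholder values; only set the providers you actually use):

```bash
export NV_API_KEY="sk-..."             # used with --nvidia
export ANTHROPIC_API_KEY="sk-ant-..."  # direct Anthropic access (default client)
export AZURE_OPENAI_KEY="..."          # Azure OpenAI (default client)
```
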
## Quick start

### 1. Build the per-question retrieval cache (one-time)

```bash
python scripts/build_retrieval_cache.py \
    --dataset dataset/evolv_mem_v5.json \
    --all_sessions dataset/all_sessions.json \
    --out_dir response_cache/retrieval/
```

### 2. Shard the dataset for parallel runs

```bash
python scripts/make_v5_shards.py \
    --dataset dataset/evolv_mem_v5.json \
    --ret_cache_jsonl response_cache/retrieval/flat-gte/v5_retrievallog_turn_flat-gte \
    --out_dir output/shards/v5_run_nchunks10/ \
    --num_shards 8
```

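After sharding, the output directory contains one dataset shard and one retrieval-cache shard per split, which step 3 consumes (a sketch assuming the default naming with `--num_shards 8`):

```
output/shards/v5_run_nchunks10/
├── dataset/shard_00.json    ... shard_07.json
└── ret_cache/shard_00.jsonl ... shard_07.jsonl
```
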
### 3. Run the QA pipeline

Flat dense retrieval @ top-k=20 (single shard, e.g. for smoke testing):

```bash
export ret_cache="output/shards/v5_run_nchunks10/ret_cache/shard_00.jsonl"
python main.py \
    --in_file output/shards/v5_run_nchunks10/dataset/shard_00.json \
    --out_file output/shards/v5_run_nchunks10/dense_gte_topk20/part_00.jsonl \
    --model_name gpt-5.5 \
    --top_k 20 \
    --n_chunks 10 \
    --nvidia \
    --all_sessions_file dataset/all_sessions.json \
    --no_semantic \
    --mode embed
```

Agentic retrieval over hierarchical memory:

```bash
python main.py \
    --in_file output/shards/v5_run_nchunks10/dataset/shard_00.json \
    --out_file output/shards/v5_run_nchunks10/agentic_hier/part_00.jsonl \
    --model_name gpt-5.5 \
    --top_k 20 \
    --n_chunks 10 \
    --nvidia \
    --all_sessions_file dataset/all_sessions.json \
    --hier_v2 --hier_union \
    --mode agent
```

To launch the full 8-shard parallel sweep on a SLURM cluster, edit and submit `scripts/slurm/example_dense_retrieval.slurm` or `scripts/slurm/example_agentic_retrieval.slurm`.

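Without SLURM, the same sweep is just a loop over the shards; a minimal sketch for the dense configuration above, assuming the 8-shard layout from step 2 (shards are independent, so they can also be run concurrently):

```bash
for i in 0{0..7}; do
  # per-shard retrieval cache, as in the single-shard example above
  export ret_cache="output/shards/v5_run_nchunks10/ret_cache/shard_${i}.jsonl"
  python main.py \
      --in_file output/shards/v5_run_nchunks10/dataset/shard_${i}.json \
      --out_file output/shards/v5_run_nchunks10/dense_gte_topk20/part_${i}.jsonl \
      --model_name gpt-5.5 --top_k 20 --n_chunks 10 --nvidia \
      --all_sessions_file dataset/all_sessions.json \
      --no_semantic --mode embed
done
```
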
### 4. Merge shards and evaluate

```bash
python scripts/merge_jsonl_by_dataset_order.py \
    --dataset dataset/evolv_mem_v5.json \
    --parts_glob "output/shards/v5_run_nchunks10/dense_gte_topk20/part_*.jsonl" \
    --out_file output/v5_run_dense_gte_topk20.jsonl

python evaluate_qa.py \
    --hyp_file output/v5_run_dense_gte_topk20.jsonl \
    --ref_file dataset/evolv_mem_v5.json \
    --eval_model_name gpt-5.2 \
    --eval_mode both \
    --nvidia
```

The evaluator caches an atomic rubric for each question (`<dataset>.atomic-v1.rubric.json`) so that subsequent runs reuse it.

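To force the rubric to be rebuilt (for example after editing the reference answers), delete the cached file. Assuming it is written alongside the dataset following the pattern above (the exact expansion of `<dataset>` may differ), something like:

```bash
rm dataset/evolv_mem_v5*.atomic-v1.rubric.json
```
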
## Pipeline modes

`main.py --mode` selects how a question is answered:

- `embed`: top-k flat dense retrieval (GTE 7B), then a single LLM call to answer.
- `keyword`: LLM-generated keywords + lexical matching, then answer (see the sketch after this list).
- `agent`: Plan-Act-Read loop. Combine `--hier_v2` (semantic-summary stage) and `--hier_union` (union with flat top-k) for the hierarchical-memory variant.

`--no_semantic` disables the semantic-summary memory layer (flat memory).

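A keyword-mode run uses the same interface as the quick-start commands; a minimal sketch mirroring the dense-retrieval command (the `keyword/` output directory is just an illustrative name):

```bash
python main.py \
    --in_file output/shards/v5_run_nchunks10/dataset/shard_00.json \
    --out_file output/shards/v5_run_nchunks10/keyword/part_00.jsonl \
    --model_name gpt-5.5 --top_k 20 --n_chunks 10 --nvidia \
    --all_sessions_file dataset/all_sessions.json \
    --no_semantic --mode keyword
```
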
## Baselines

The three external baselines (MemoChat, RAPTOR, ReadAgent) live under `baselines/`, together with our thin wrappers (`run_<baseline>_baseline.py`). Each baseline's upstream LICENSE is preserved.

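Each wrapper is invoked directly; the flags differ per baseline, so inspect each script's CLI first (a sketch, assuming the wrappers expose a standard `--help` interface):

```bash
python baselines/MemoChat/run_memochat_baseline.py --help
python baselines/raptor/run_raptor_baseline.py --help
python baselines/read-agent/run_readagent_baseline.py --help
```
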
## License

This repository is released under the license stated in the corresponding LICENSE file (TBD prior to release). Upstream baselines retain their original licenses.
baselines/MemoChat/LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Lu junru

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
baselines/MemoChat/README.md ADDED
@@ -0,0 +1,47 @@
# MemoChat
MemoChat: Tuning LLMs to Use Memos for Consistent Long-Range Open-domain Conversation

## Environment
We provide [core_requirement.txt](core_requirement.txt) for your convenience.

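A minimal environment setup (a sketch; the pinned versions live in the file itself):

```
pip install -r core_requirement.txt
```
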
## Model Weights
The initial models we used are [fastchat models (v1.3)](https://lmsys.org/blog/2023-03-30-vicuna/). Below are the model weights of our fine-tuned versions. Our models are built upon FastChat models, so we adopt the same `cc-by-nc-sa-4.0` license.

| Name | Share Link |
| --- | --- |
| MemoChat-Fastchat-T5-3B | https://huggingface.co/Junrulu/MemoChat-Fastchat-T5-3B |
| MemoChat-Vicuna-7B | https://huggingface.co/Junrulu/MemoChat-Vicuna-7B |
| MemoChat-Vicuna-13B | https://huggingface.co/Junrulu/MemoChat-Vicuna-13B |
| MemoChat-Vicuna-33B | https://huggingface.co/Junrulu/MemoChat-Vicuna-33B |

## Workflow
`RootPath` is the absolute path of this repo. Download the initial models and put them in the [model](model) folder.
### Instruction Tuning
```
Run `bash code/scripts/tuning.sh RootPath`. Intermediate evaluations are included in this script as well.
```

### MemoChat Testing
```
Run `bash code/scripts/memochat.sh RootPath` for pipeline testing with fine-tuned models.
Run `bash code/scripts/memochat_gpt.sh RootPath` for pipeline testing with the GPT-3.5 API.
Run `bash code/scripts/llm_judge.sh RootPath` for the GPT-4 judge (an OpenAI API key is required).
```

### Our Results
We provide our prediction results [here](https://drive.google.com/file/d/1jGNhT3iPXEA8B2fXHZ2Einy1AMre-8xB/view?usp=sharing).

## Acknowledgement
We thank the [Vicuna project](https://github.com/lm-sys/FastChat/tree/main) for their great work.

## Citation
```
@misc{lu2023memochat,
      title={MemoChat: Tuning LLMs to Use Memos for Consistent Long-Range Open-Domain Conversation},
      author={Junru Lu and Siyu An and Mingbao Lin and Gabriele Pergola and Yulan He and Di Yin and Xing Sun and Yunsheng Wu},
      year={2023},
      eprint={2308.08239},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
baselines/MemoChat/code/codes/api/gpt_2k.py ADDED
@@ -0,0 +1,90 @@
1
+ import re
2
+ import json
3
+ import openai
4
+ import time
5
+ import sys
6
+ import tiktoken
7
+
8
+ input_data = sys.argv[1]
9
+ openai_modelid = sys.argv[2]
10
+ openai.api_key = sys.argv[3]
11
+ output_path = sys.argv[4]
12
+ prompt_path = sys.argv[5]
13
+ encoding = tiktoken.encoding_for_model(openai_modelid)
14
+
15
+ q_pre = ""
16
+ qa_link = ""
17
+ MaxLen = 2048
18
+ TarLen = 512
19
+ TaskTarLen = {
20
+ "chatting_dialogsum": MaxLen,
21
+ "chatting_alpacagpt4": MaxLen,
22
+ "writing_topiocqa": TarLen // 2,
23
+ "writing_dialogsum": TarLen,
24
+ "retrieval_dialogsum": 32,
25
+ "retrieval_topiocqa": 32
26
+ }
27
+
28
+ prompts = json.load(open(prompt_path, "r"))
29
+
30
+ def normalize_chatting_outputs(model_outputs):
31
+ def white_space_fix(text):
32
+ lines = text.split("\n")
33
+ result = []
34
+ for line in lines:
35
+ result.append(' '.join(line.split()))
36
+ output = '\n'.join(result)
37
+ return output
38
+ return white_space_fix(model_outputs)
39
+
40
+ def gen_model_output(input_qs, task_type):
41
+ input_qs_token_l = len(encoding.encode(input_qs)) # token num
42
+ input_qs_word_l = len(input_qs.split(" ")) # word num
43
+ qs_w_t_ratio = input_qs_word_l / input_qs_token_l
44
+ max_word_num = int((MaxLen - TarLen) * qs_w_t_ratio)
45
+ input_qs = " ".join(input_qs.split(" ")[-max_word_num:])
46
+ target_len = TaskTarLen[task_type]
47
+ messages = [{"role": "system", "content": input_qs}]
48
+ for _ in range(5):
49
+ try:
50
+ chat = openai.ChatCompletion.create(
51
+ model=openai_modelid, messages=messages, max_tokens=target_len, temperature=0.2
52
+ )
53
+ break
54
+ except:
55
+ time.sleep(5)
56
+ model_outputs = chat.choices[0].message.content
57
+ return model_outputs
58
+
59
+ def run_eval():
60
+ data = json.load(open(input_data, "r"))
61
+ output_data = []
62
+ for d in data:
63
+ print("=" * 20 + "start of question {}".format(d["id"]) + "=" * 20)
64
+ new_d = d
65
+
66
+ history = []
67
+ for l_i in range(len(new_d["conversations"])):
68
+ if l_i % 2 == 1:
69
+ bot_thinking = {"retrieval": "", "summarization": ""}
70
+ print("=" * 20 + "start of turn {}".format(l_i // 2 + 1) + "=" * 20)
71
+ user = "user: " + new_d["conversations"][l_i - 1]["value"]
72
+
73
+ system_insturction = prompts["chatting"]["system"]
74
+ task_instruction = prompts["chatting"]["instruction"]
75
+ task_case = "```\nRecent Dialogs:\n" + " ### ".join([hrd.replace("\n", " ") for hrd in history]) + "\n```\n\nUser Input:\n" + user + " ### bot: "
76
+ qs = system_insturction + task_case + task_instruction
77
+ print(qs + "\n\n")
78
+ outputs = gen_model_output(qs, "chatting_dialogsum")
79
+ outputs = normalize_chatting_outputs(outputs)
80
+ history += [user, "bot: " + outputs]
81
+ print("bot: " + outputs + "\n")
82
+ print("=" * 20 + "end of turn {}".format(l_i // 2 + 1) + "=" * 20)
83
+ new_d["conversations"][l_i]["thinking"] = json.dumps(bot_thinking)
84
+ new_d["conversations"][l_i]["value"] = outputs
85
+
86
+ output_data.append(new_d)
87
+ json.dump(output_data, open(output_path, "w"), indent=2)
88
+
89
+ if __name__ == "__main__":
90
+ run_eval()
baselines/MemoChat/code/codes/api/gpt_memochat.py ADDED
@@ -0,0 +1,189 @@
1
+ import re
2
+ import json
3
+ import openai
4
+ import time
5
+ import sys
6
+ import tiktoken
7
+ from random import sample
8
+
9
+ input_data = sys.argv[1]
10
+ openai_modelid = sys.argv[2]
11
+ openai.api_key = sys.argv[3]
12
+ output_path = sys.argv[4]
13
+ prompt_path = sys.argv[5]
14
+ encoding = tiktoken.encoding_for_model(openai_modelid)
15
+
16
+ q_pre = ""
17
+ qa_link = ""
18
+ MaxLen = 2048
19
+ TarLen = 512
20
+ TaskTarLen = {
21
+ "chatting_dialogsum": MaxLen,
22
+ "chatting_alpacagpt4": MaxLen,
23
+ "writing_topiocqa": TarLen // 2,
24
+ "writing_dialogsum": TarLen,
25
+ "retrieval_dialogsum": 32,
26
+ "retrieval_topiocqa": 32
27
+ }
28
+
29
+ prompts = json.load(open(prompt_path, "r"))
30
+
31
+ def normalize_model_outputs(model_text):
32
+ extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", "")) for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
33
+ model_outputs = []
34
+ ti = 0
35
+ while ti + 7 < len(extracted_elements):
36
+ if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
37
+ try:
38
+ model_outputs.append({"topic": extracted_elements[ti + 1], "summary": extracted_elements[ti + 3], "start": int(extracted_elements[ti + 5]), "end": int(extracted_elements[ti + 7])})
39
+ except:
40
+ pass
41
+ ti += 1
42
+ return model_outputs
43
+
44
+ def normalize_chatting_outputs(model_outputs):
45
+ def white_space_fix(text):
46
+ lines = text.split("\n")
47
+ result = []
48
+ for line in lines:
49
+ result.append(' '.join(line.split()))
50
+ output = '\n'.join(result)
51
+ return output
52
+ return white_space_fix(model_outputs)
53
+
54
+ def gen_model_output(input_qs, task_type):
55
+ input_qs_token_l = len(encoding.encode(input_qs)) # token num
56
+ input_qs_word_l = len(input_qs.split(" ")) # word num
57
+ qs_w_t_ratio = input_qs_word_l / input_qs_token_l
58
+ max_word_num = int((MaxLen - TarLen) * qs_w_t_ratio)
59
+ input_qs = " ".join(input_qs.split(" ")[-max_word_num:])
60
+ target_len = TaskTarLen[task_type]
61
+ messages = [{"role": "system", "content": input_qs}]
62
+ for _ in range(5):
63
+ try:
64
+ chat = openai.ChatCompletion.create(
65
+ model=openai_modelid, messages=messages, max_tokens=target_len, temperature=0.2
66
+ )
67
+ break
68
+ except:
69
+ time.sleep(5)
70
+ model_outputs = chat.choices[0].message.content
71
+ return model_outputs
72
+
73
+ def run_summary(history, memo, bot_thinking):
74
+ system_insturction = prompts["writing_dialogsum"]["system"]
75
+ task_instruction = prompts["writing_dialogsum"]["instruction"]
76
+ history_log = "\n\n```\nTask Conversation:\n" + "\n".join(["(line {}) {}".format(h_i + 1, h.replace("\n", " ")) for h_i, h in enumerate(history["Recent Dialogs"][2:])])
77
+ qs = q_pre + system_insturction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + history_log + "\n```" + task_instruction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + qa_link
78
+ # print("-" * 20 + "summarizing" + "-" * 20)
79
+ # print(qs)
80
+ # print("-" * 20 + "summarizing" + "-" * 20)
81
+ sum_history = gen_model_output(qs, "writing_dialogsum")
82
+ sum_history = normalize_model_outputs(sum_history)
83
+ # print("-" * 20 + "summarization" + "-" * 20)
84
+ # print(sum_history)
85
+ # print("-" * 20 + "summarization" + "-" * 20)
86
+ for s in sum_history:
87
+ memo[s["topic"]] = memo.get(s["topic"], []) + [{"summary": s["summary"], "dialogs": history["Recent Dialogs"][2:][(s["start"] - 1):s["end"]]}]
88
+ if len(sum_history) == 0:
89
+ si_0, si_1 = sample(list(range(len(history["Recent Dialogs"][2:]))), 2)
90
+ memo["NOTO"].append({"summary": "Partial dialogs about: {} or {}.".format(history["Recent Dialogs"][2:][si_0], history["Recent Dialogs"][2:][si_1]), "dialogs": history["Recent Dialogs"][2:]})
91
+ history["Recent Dialogs"] = history["Recent Dialogs"][-2:]
92
+ bot_thinking["summarization"] = {"input": qs, "output": sum_history}
93
+ return history, memo, bot_thinking
94
+
95
+ def run_retrieval(history, memo, bot_thinking):
96
+ topics = []
97
+ for k, v in memo.items():
98
+ for vv in v:
99
+ topics.append((k, vv["summary"], vv["dialogs"]))
100
+ system_insturction = prompts["retrieval"]["system"]
101
+ task_instruction = prompts["retrieval"]["instruction"]
102
+ task_case = "```\nQuery Sentence:\n" + history["User Input"][6:] + "\nTopic Options:\n" + \
103
+ "\n".join(["({}) {}".format(v_i + 1, v[0] + ". " + v[1]) for v_i, v in enumerate(topics)]) + "\n```"
104
+ qs = q_pre + system_insturction.replace("OPTION", str(len(topics))) + task_case + task_instruction.replace("OPTION", str(len(topics))) + qa_link
105
+ # print("-" * 20 + "retrieving" + "-" * 20)
106
+ # print(qs)
107
+ # print("-" * 20 + "retrieving" + "-" * 20)
108
+ outputs = gen_model_output(qs, "retrieval_dialogsum")
109
+ # print("-" * 20 + "retrieval" + "-" * 20)
110
+ # print(outputs)
111
+ # print("-" * 20 + "retrieval" + "-" * 20)
112
+ outputs = outputs.split("#")
113
+ chosen_topics = []
114
+ for output in outputs:
115
+ try:
116
+ index_ = int(output) - 1
117
+ except:
118
+ continue
119
+ if index_ < len(topics) and "NOTO" not in topics[index_]:
120
+ chosen_topics.append(topics[index_])
121
+ if len(chosen_topics) > 0:
122
+ history["Related Topics"] = [ct[0] for ct in chosen_topics]
123
+ history["Related Summaries"] = [ct[1] for ct in chosen_topics]
124
+ history["Related Dialogs"] = [" ### ".join(ct[2]) for ct in chosen_topics]
125
+ else:
126
+ history["Related Topics"] = []
127
+ history["Related Summaries"] = []
128
+ history["Related Dialogs"] = []
129
+ bot_thinking["retrieval"] = {"input": qs, "output": outputs}
130
+ return history, bot_thinking
131
+
132
+ def run_eval():
133
+ data = json.load(open(input_data, "r"))
134
+ output_data = []
135
+ for d in data:
136
+ print("=" * 20 + "start of question {}".format(d["id"]) + "=" * 20)
137
+ new_d = d
138
+
139
+ history = {
140
+ "Recent Dialogs": ["user: Hi!", "bot: Hi! How can I help you today?"],
141
+ "Related Topics": [],
142
+ "Related Summaries": [],
143
+ "Related Dialogs": [],
144
+ "User Input": "",
145
+ }
146
+ memo = {
147
+ "NOTO": [{"summary": "None of the others.", "dialogs": []}]
148
+ }
149
+
150
+ for l_i in range(len(new_d["conversations"])):
151
+ if l_i % 2 == 1:
152
+ bot_thinking = {"retrieval": "", "summarization": ""}
153
+ print("=" * 20 + "start of turn {}".format(l_i // 2 + 1) + "=" * 20)
154
+ user = "user: " + new_d["conversations"][l_i - 1]["value"]
155
+ print(user + "\n\n")
156
+
157
+ # create summary if recent dialogs exceed threshold
158
+ if len(" ### ".join(history["Recent Dialogs"]).split(" ")) > (MaxLen // 2) or len(history["Recent Dialogs"]) >= 10:
159
+ history, memo, bot_thinking = run_summary(history, memo, bot_thinking)
160
+
161
+ # retrieve most related topics for every new user input
162
+ history["User Input"] = user
163
+ if len(memo.keys()) > 1:
164
+ history, bot_thinking = run_retrieval(history, memo, bot_thinking)
165
+
166
+ # generate bot response
167
+ system_insturction = prompts["chatting"]["system"]
168
+ task_instruction = prompts["chatting"]["instruction"]
169
+ task_case = "```\nRelated Evidences:\n" + "\n".join(["({}) {}".format(r_tsd_i + 1, {
170
+ "Related Topics": history["Related Topics"][r_tsd_i],
171
+ "Related Summaries": history["Related Summaries"][r_tsd_i],
172
+ "Related Dialogs": history["Related Dialogs"][r_tsd_i]
173
+ }) for r_tsd_i in range(len(history["Related Topics"]))]) + "\n\nRecent Dialogs:\n" + \
174
+ " ### ".join([hrd.replace("\n", " ") for hrd in history["Recent Dialogs"]]) + "\n```\n\nUser Input:\n" + history["User Input"] + " ### bot: "
175
+ qs = q_pre + system_insturction + task_case + task_instruction + qa_link
176
+ outputs = gen_model_output(qs, "chatting_dialogsum")
177
+ outputs = normalize_chatting_outputs(outputs)
178
+ history["Recent Dialogs"] += [user, "bot: " + outputs]
179
+ print("bot: " + outputs + "\n")
180
+ print("=" * 20 + "end of turn {}".format(l_i // 2 + 1) + "=" * 20)
181
+ # print("\n\n\n\n")
182
+ new_d["conversations"][l_i]["thinking"] = json.dumps(bot_thinking)
183
+ new_d["conversations"][l_i]["value"] = outputs
184
+
185
+ output_data.append(new_d)
186
+ json.dump(output_data, open(output_path, "w"), indent=2)
187
+
188
+ if __name__ == "__main__":
189
+ run_eval()
baselines/MemoChat/code/codes/api/llm_judge.py ADDED
@@ -0,0 +1,96 @@
1
+ import sys
2
+ import os
3
+ import json
4
+ import re
5
+ import openai
6
+ import time
7
+ import tiktoken
8
+
9
+ input_data = sys.argv[1]
10
+ openai_modelid = sys.argv[2]
11
+ openai.api_key = sys.argv[3]
12
+ output_path = sys.argv[4]
13
+ prompt_path = sys.argv[5]
14
+ encoding = tiktoken.encoding_for_model(openai_modelid)
15
+
16
+ prompts = json.load(open(prompt_path, "r"))
17
+ judge_prompt_raw = prompts["judge"]["system"]
18
+
19
+ def gen_model_output(input_qs):
20
+ input_qs_token_l = len(encoding.encode(input_qs)) # token num
21
+ input_qs_word_l = len(input_qs.split(" ")) # word num
22
+ qs_w_t_ratio = input_qs_word_l / input_qs_token_l
23
+ max_word_num = int(4096 * qs_w_t_ratio)
24
+ input_qs = " ".join(input_qs.split(" ")[-max_word_num:])
25
+ messages = [{"role": "system", "content": input_qs}]
26
+ chat = None
27
+ for _ in range(5):
28
+ try:
29
+ chat = openai.ChatCompletion.create(
30
+ model=openai_modelid, messages=messages
31
+ )
32
+ break
33
+ except:
34
+ time.sleep(5)
35
+ if chat is None:
36
+ return "Cannot generate output."
37
+ model_outputs = chat.choices[0].message.content
38
+ return model_outputs
39
+
40
+ data = json.load(open(input_data, "r"))
41
+
42
+ # do llm judge
43
+ output_ratings = []
44
+ for d in data:
45
+ print("=" * 20 + "Processing: " + d["id"] + "=" * 20)
46
+ judge_conversation = []
47
+ d_conversations = d['conversations']
48
+ last_q = d_conversations[-2]
49
+ turn_infos = last_q["turn-info"].split("-")
50
+ r_turns = [turn_infos[0] + "-" + turn_infos[1]]
51
+ if len(turn_infos) == 5:
52
+ r_turns.append(turn_infos[2] + "-" + turn_infos[3])
53
+ for l_i in range(len(d_conversations) // 2 - 1):
54
+ if d_conversations[l_i * 2]["turn-info"][:-2] in r_turns:
55
+ judge_conversation.append("user: " + d_conversations[l_i * 2]["value"])
56
+ judge_conversation.append("bot: " + d_conversations[l_i * 2 + 1]["value"])
57
+ judge_prompt = judge_prompt_raw.replace("RCH_0", "\n".join(judge_conversation)).replace("UQ_1", "user: " + last_q["value"]).replace("BR_2", "bot: " + d_conversations[-1]["value"])
58
+ print(judge_prompt)
59
+ print('-' * 20)
60
+ outputs = gen_model_output(judge_prompt)
61
+ print(outputs)
62
+ print("=" * 20 + "Processed: " + d["id"] + "=" * 20)
63
+ match = re.search(r'\[\[(\d+)\]\]', outputs)
64
+ try:
65
+ rating = int(match.group(1))
66
+ except:
67
+ rating = None
68
+ output_ratings.append({
69
+ "id": d["id"],
70
+ "type": d["type"],
71
+ "judge_prompt": judge_prompt,
72
+ "evaluation": outputs,
73
+ "rating": rating
74
+ })
75
+ json.dump(output_ratings, open(output_path, "w"), indent=2)
76
+
77
+ # compute score
78
+ count = {
79
+ "continuation": [],
80
+ "retrospection": [],
81
+ "conjunction": []
82
+ }
83
+ for d in output_ratings:
84
+ if d["type"] == "continuation":
85
+ count["continuation"].append(d["rating"])
86
+ elif d["type"] == "retrospection":
87
+ count["retrospection"].append(d["rating"])
88
+ elif d["type"] == "conjunction":
89
+ count["conjunction"].append(d["rating"])
90
+ print("Retrospection Score: {}, Continuation Score: {}, Conjunction Score: {}, Overall Score: {} of file {}".format(
91
+ round(sum(count["retrospection"]) / len(count["retrospection"]), 2),
92
+ round(sum(count["continuation"]) / len(count["continuation"]), 2),
93
+ round(sum(count["conjunction"]) / len(count["conjunction"]), 2),
94
+ round(sum(count["continuation"] + count["retrospection"] + count["conjunction"]) / len(count["continuation"] + count["retrospection"] + count["conjunction"]), 2),
95
+ input_data
96
+ ))
baselines/MemoChat/code/codes/eval/eval_instruction_tuning_tasks.py ADDED
@@ -0,0 +1,169 @@
1
+ import json
2
+ import re
3
+ import string
4
+ import sys
5
+ import random
6
+ from argparse import ArgumentParser
7
+ from collections import Counter
8
+
9
+ from evaluate import load
10
+ bertscore = load("bertscore")
11
+
12
+ refer_file_path = sys.argv[1]
13
+ input_file_path = sys.argv[2]
14
+
15
+ conversations = open(refer_file_path, "r").readlines()
16
+ conversations_dict = {}
17
+ for conversation in conversations:
18
+ conv_l = json.loads(conversation.strip())
19
+ conversations_dict[conv_l["question_id"]] = (conv_l["text"], conv_l["answer"], conv_l["type"])
20
+
21
+ class Metrics():
22
+ def __init__(self):
23
+ pass
24
+
25
+ def __normalize_text(self, s_text):
26
+ """Lower text and remove punctuation, articles and extra whitespace."""
27
+ def remove_articles(text):
28
+ regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
29
+ return re.sub(regex, ' ', text)
30
+
31
+ def white_space_fix(text):
32
+ return ' '.join(text.split())
33
+
34
+ def remove_punc(text):
35
+ exclude = set(string.punctuation)
36
+ return ''.join(ch for ch in text if ch not in exclude)
37
+
38
+ def lower(text):
39
+ return text.lower()
40
+
41
+ return white_space_fix(remove_articles(remove_punc(lower(s_text))))
42
+
43
+ def __normalize_model_outputs(self, model_text, type_category):
44
+ """post process of memo writing outputs"""
45
+ extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", "")) for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
46
+ model_outputs = []
47
+ ti = 0
48
+
49
+ if "dialogsum" in type_category:
50
+ while ti + 7 < len(extracted_elements):
51
+ if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
52
+ try:
53
+ model_outputs.append({"topic": extracted_elements[ti + 1], "summary": extracted_elements[ti + 3], "start": int(extracted_elements[ti + 5]), "end": int(extracted_elements[ti + 7])})
54
+ except:
55
+ pass
56
+ ti += 1
57
+ else:
58
+ while ti + 5 < len(extracted_elements):
59
+ if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "start" and extracted_elements[ti + 4] == "end":
60
+ try:
61
+ model_outputs.append({"topic": extracted_elements[ti + 1], "start": int(extracted_elements[ti + 3]), "end": int(extracted_elements[ti + 5])})
62
+ except:
63
+ pass
64
+ ti += 1
65
+
66
+ return model_outputs
67
+
68
+ def __get_class_span_dict__(self, label, checkitem_k):
69
+ class_span = {}
70
+ for i in range(len(label)):
71
+ checkitem_i = self.__normalize_text(label[i][checkitem_k])
72
+ class_span[(label[i]['start'], label[i]['end'])] = class_span.get((label[i]['start'], label[i]['end']), []) + [checkitem_i]
73
+ return class_span
74
+
75
+ def __get_intersect_by_entity__(self, pred_class_span, label_class_span):
76
+ '''
77
+ return the count of correct entity
78
+ '''
79
+ cnt = 0
80
+ for label in label_class_span:
81
+ cnt += len(list(set(label_class_span[label]).intersection(set(pred_class_span.get(label,[])))))
82
+ return cnt
83
+
84
+ def __get_bertscore_by_entity__(self, pred_class_span, label_class_span):
85
+ '''
86
+ return the summed BERTScore precision over matched entity spans
87
+ '''
88
+ cnt = 0
89
+ for label in label_class_span:
90
+ if label in pred_class_span:
91
+ references = [label_class_span[label]]
92
+ prediction = [pred_class_span[label][0]]
93
+ result = bertscore.compute(predictions=prediction, references=references, model_type="microsoft/deberta-xlarge-mnli")["precision"][0]
94
+ cnt += result
95
+ return cnt
96
+
97
+ def __get_cnt__(self, label_class_span):
98
+ '''
99
+ return the count of entities
100
+ '''
101
+ cnt = 0
102
+ for label in label_class_span:
103
+ cnt += len(label_class_span[label])
104
+ # cnt += 1 # set as 1 if we have multiple references
105
+ return cnt
106
+
107
+ def metrics_by_entity_(self, pred, label, checkitem_k):
108
+ '''
109
+ return entity level count of total prediction, true labels, and correct prediction
110
+ '''
111
+ pred_class_span = self.__get_class_span_dict__(pred, checkitem_k)
112
+ label_class_span = self.__get_class_span_dict__(label, checkitem_k)
113
+ pred_cnt = self.__get_cnt__(pred_class_span)
114
+ label_cnt = self.__get_cnt__(label_class_span)
115
+ if checkitem_k == "topic":
116
+ correct_cnt = self.__get_intersect_by_entity__(pred_class_span, label_class_span)
117
+ elif checkitem_k == "summary":
118
+ correct_cnt = self.__get_bertscore_by_entity__(pred_class_span, label_class_span)
119
+ return pred_cnt, label_cnt, correct_cnt
120
+
121
+ def p_r_f1_by_entity(self, pc, lc, cc):
122
+ precision = cc / (pc + 1e-8)
123
+ recall = cc / (lc + 1e-8)
124
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
125
+ return round(precision * 100, 2), round(recall * 100, 2), round(f1 * 100, 2)
126
+
127
+ def metrics_by_entity_files(self, pred_file, checkitem_k, type_key):
128
+ pred_cnt = 0
129
+ label_cnt = 0
130
+ correct_cnt = 0
131
+ for l_i, line in enumerate(open(pred_file, "r").readlines()):
132
+ eles = json.loads(line.strip())
133
+
134
+ if (type_key not in conversations_dict[eles["question_id"]][2]) or (conversations_dict[eles["question_id"]][2] == "writing_topiocqa" and checkitem_k == "summary"):
135
+ continue
136
+ if type_key == "writing":
137
+ model_text = self.__normalize_model_outputs(eles["text"], conversations_dict[eles["question_id"]][2])
138
+ label_i = json.loads(conversations_dict[eles["question_id"]][1])
139
+ elif type_key == "retrieval":
140
+ model_text = [{"topic": v, "start": 0, "end": 0} for v in set(eles["text"].split("#"))]
141
+ label_i = [{"topic": v, "start": 0, "end": 0} for v in set(conversations_dict[eles["question_id"]][1].split("#"))]
142
+ else:
143
+ model_text = [{"summary": eles["text"], "start": 0, "end": 0}]
144
+ label_i = [{"summary": conversations_dict[eles["question_id"]][1], "start": 0, "end": 0}]
145
+
146
+ p_cnt, l_cnt, c_cnt = self.metrics_by_entity_(model_text, label_i, checkitem_k)
147
+ p_i, r_i, f_i = self.p_r_f1_by_entity(p_cnt, l_cnt, c_cnt)
148
+ # if p_i + r_i + f_i != 0:
149
+ # print("Q ID: " + str(eles["question_id"]) + "\n")
150
+ # print(conversations_dict[eles["question_id"]][0] + "\n")
151
+ # # print("Raw Ouput: " + eles["text"] + "\n")
152
+ # print("Model: {}".format(model_text) + "\n")
153
+ # print("Refer: {}".format(label_i) + "\n")
154
+ # print("Case P/R/F1: {}%, {}%, {}%".format(p_i, r_i, f_i))
155
+ # print("=" * 20)
156
+ pred_cnt += p_cnt
157
+ label_cnt += l_cnt
158
+ correct_cnt += c_cnt
159
+ return self.p_r_f1_by_entity(pred_cnt, label_cnt, correct_cnt)
160
+
161
+ calculate_metrics = Metrics()
162
+ p_a, r_a, f1_a = calculate_metrics.metrics_by_entity_files(input_file_path, 'topic', 'writing') # both
163
+ print("Overall P/R/F1 of topic: {}%, {}%, {}%".format(p_a, r_a, f1_a))
164
+ p_b, r_b, f1_b = calculate_metrics.metrics_by_entity_files(input_file_path, 'summary', 'writing') # dialogsum
165
+ print("Overall P/R/F1 of summary: {}%, {}%, {}%".format(p_b, r_b, f1_b))
166
+ _, _, f1 = calculate_metrics.metrics_by_entity_files(input_file_path, "topic", "retrieval") # both
167
+ print("Retrieval F1: {}%".format(f1))
168
+ p, _, _ = calculate_metrics.metrics_by_entity_files(input_file_path, "summary", "chatting") # dialogsum
169
+ print("Chatting similarity: {}%".format(p))
baselines/MemoChat/code/codes/eval/get_model_infer_memochat.py ADDED
@@ -0,0 +1,291 @@
1
+ import argparse
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, AutoModelForSeq2SeqLM
3
+ from optimum.bettertransformer import BetterTransformer
4
+ import torch
5
+ import os
6
+ import json
7
+ import re
8
+ import ray
9
+ import warnings
10
+ from random import sample
11
+ warnings.filterwarnings("ignore")
12
+
13
+ q_pre = "<s>\n"
14
+ qa_link = "\n"
15
+ MaxLen = 2048
16
+ TarLen = 512
17
+ TaskTarLen = {
18
+ "chatting_dialogsum": MaxLen,
19
+ "chatting_alpacagpt4": MaxLen,
20
+ "writing_topiocqa": TarLen // 2,
21
+ "writing_dialogsum": TarLen,
22
+ "retrieval_dialogsum": 32,
23
+ "retrieval_topiocqa": 32
24
+ }
25
+
26
+ def get_gpu_memory(num_gpus):
27
+ """Get available memory for each GPU."""
28
+ gpu_memory = []
29
+ for gpu_id in range(num_gpus):
30
+ with torch.cuda.device(gpu_id):
31
+ device = torch.cuda.current_device()
32
+ gpu_properties = torch.cuda.get_device_properties(device)
33
+ total_memory = gpu_properties.total_memory / (1024**3)
34
+ allocated_memory = torch.cuda.memory_allocated() / (1024**3)
35
+ available_memory = total_memory - allocated_memory
36
+ gpu_memory.append(available_memory)
37
+ return gpu_memory
38
+
39
+ def normalize_model_outputs(model_text):
40
+ """post processing function of memo writing task"""
41
+ extracted_elements = [re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", "")) for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)]
42
+ model_outputs = []
43
+ ti = 0
44
+ while ti + 7 < len(extracted_elements):
45
+ if extracted_elements[ti] == "topic" and extracted_elements[ti + 2] == "summary" and extracted_elements[ti + 4] == "start" and extracted_elements[ti + 6] == "end":
46
+ try:
47
+ model_outputs.append({"topic": extracted_elements[ti + 1], "summary": extracted_elements[ti + 3], "start": int(extracted_elements[ti + 5]), "end": int(extracted_elements[ti + 7])})
48
+ except:
49
+ pass
50
+ ti += 1
51
+ return model_outputs
52
+
53
+ def normalize_chatting_outputs(model_outputs):
54
+ """post processing function of chatting response"""
55
+ def white_space_fix(text):
56
+ lines = text.split("\n")
57
+ result = []
58
+ for line in lines:
59
+ result.append(' '.join(line.split()))
60
+ output = '\n'.join(result)
61
+ return output
62
+ return white_space_fix(model_outputs)
63
+
64
+ def gen_model_output(model_path, model, tokenizer, input_qs, local_check, task_type):
65
+ if local_check:
66
+ from faker import Faker
67
+ fake = Faker(locale="en")
68
+ return fake.text(2000)
69
+ if "writing" in task_type:
70
+ eos_token_ids = [tokenizer.eos_token_id, tokenizer.encode("]", add_special_tokens=False)[0]]
71
+ elif "retrieval" in task_type:
72
+ eos_token_ids = [tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[0], tokenizer.encode(" ", add_special_tokens=False)[0]]
73
+ else:
74
+ eos_token_ids = [tokenizer.eos_token_id]
75
+ if "t5" in model_path:
76
+ # t5 model may need larger repetition penalty value to help get generation stability
77
+ input_ids = tokenizer([input_qs], max_length=MaxLen, truncation=True, add_special_tokens=False).input_ids
78
+ target_len = TaskTarLen[task_type]
79
+ repetition_penalty_value = 1.0
80
+ else:
81
+ input_ids = tokenizer([input_qs], max_length=(MaxLen - TarLen), truncation=True, add_special_tokens=False).input_ids
82
+ target_len = min(len(input_ids[0]) + TaskTarLen[task_type], MaxLen)
83
+ repetition_penalty_value = 1.0
84
+ output_ids = model.generate(
85
+ torch.as_tensor(input_ids).cuda(),
86
+ do_sample=True,
87
+ temperature=0.2,
88
+ max_length=target_len,
89
+ eos_token_id=eos_token_ids,
90
+ repetition_penalty=repetition_penalty_value
91
+ )
92
+ if "t5" in model_path:
93
+ output_ids = output_ids[0]
94
+ else:
95
+ output_ids = output_ids[0][len(input_ids[0]):]
96
+ model_outputs = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
97
+ return model_outputs
98
+
99
+ def run_summary(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts):
100
+ """We assume there is no overly long user input, e.g. over 1500 tokens"""
101
+ system_insturction = prompts["writing_dialogsum"]["system"]
102
+ task_instruction = prompts["writing_dialogsum"]["instruction"]
103
+ history_log = "\n\n```\nTask Conversation:\n" + "\n".join(["(line {}) {}".format(h_i + 1, h.replace("\n", " ")) for h_i, h in enumerate(history["Recent Dialogs"][2:])])
104
+ qs = q_pre + system_insturction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + history_log + "\n```" + task_instruction.replace("LINE", str(len(history["Recent Dialogs"]) - 2)) + qa_link
105
+ print("#" * 20 + "summarizing" + "#" * 20)
106
+ print(qs)
107
+ print("#" * 20 + "summarizing" + "#" * 20)
108
+ sum_history = gen_model_output(model_path, model, tokenizer, qs, local_check, "writing_dialogsum")
109
+ sum_history = normalize_model_outputs(sum_history)
110
+ print("#" * 20 + "summarization" + "#" * 20)
111
+ print(sum_history)
112
+ print("#" * 20 + "summarization" + "#" * 20)
113
+ for s in sum_history:
114
+ memo[s["topic"]] = memo.get(s["topic"], []) + [{"summary": s["summary"], "dialogs": history["Recent Dialogs"][2:][(s["start"] - 1):s["end"]]}]
115
+ if local_check:
116
+ memo["test_topic{}".format(len(memo.keys()))] = [{"summary": "test_summary{}".format(len(memo.keys())), "dialogs": history["Recent Dialogs"][2:][2:4]}]
117
+ if len(sum_history) == 0:
118
+ si_0, si_1 = sample(list(range(len(history["Recent Dialogs"][2:]))), 2)
119
+ memo["NOTO"].append({"summary": "Partial dialogs about: {} or {}.".format(history["Recent Dialogs"][2:][si_0], history["Recent Dialogs"][2:][si_1]), "dialogs": history["Recent Dialogs"][2:]})
120
+ history["Recent Dialogs"] = history["Recent Dialogs"][-2:]
121
+ bot_thinking["summarization"] = {"input": qs, "output": sum_history}
122
+ return history, memo, bot_thinking
123
+
124
+ def run_retrieval(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts):
125
+ topics = []
126
+ for k, v in memo.items():
127
+ for vv in v:
128
+ topics.append((k, vv["summary"], vv["dialogs"]))
129
+ system_insturction = prompts["retrieval"]["system"]
130
+ task_instruction = prompts["retrieval"]["instruction"]
131
+ task_case = "```\nQuery Sentence:\n" + history["User Input"][6:] + "\nTopic Options:\n" + \
132
+ "\n".join(["({}) {}".format(v_i + 1, v[0] + ". " + v[1]) for v_i, v in enumerate(topics)]) + "\n```"
133
+ qs = q_pre + system_insturction.replace("OPTION", str(len(topics))) + task_case + task_instruction.replace("OPTION", str(len(topics))) + qa_link
134
+ print("#" * 20 + "retrieving" + "#" * 20)
135
+ print(qs)
136
+ print("#" * 20 + "retrieving" + "#" * 20)
137
+ outputs = gen_model_output(model_path, model, tokenizer, qs, local_check, "retrieval_dialogsum")
138
+ print("#" * 20 + "retrieval" + "#" * 20)
139
+ print(outputs)
140
+ print("#" * 20 + "retrieval" + "#" * 20)
141
+ outputs = outputs.split("#")
142
+ chosen_topics = []
143
+ for output in outputs:
144
+ try:
145
+ index_ = int(output) - 1
146
+ except:
147
+ continue
148
+ if index_ < len(topics) and "NOTO" not in topics[index_][0]:
149
+ chosen_topics.append(topics[index_])
150
+ if local_check:
151
+ chosen_topics = sample(topics, min(len(topics) - 1, 2))
152
+ if len(chosen_topics) > 0:
153
+ history["Related Topics"] = [ct[0] for ct in chosen_topics]
154
+ history["Related Summaries"] = [ct[1] for ct in chosen_topics]
155
+ history["Related Dialogs"] = [" ### ".join(ct[2]) for ct in chosen_topics]
156
+ else:
157
+ history["Related Topics"] = []
158
+ history["Related Summaries"] = []
159
+ history["Related Dialogs"] = []
160
+ bot_thinking["retrieval"] = {"input": qs, "output": outputs}
161
+ return history, bot_thinking
162
+
163
+ @torch.inference_mode()
164
+ def get_model_answers(model_path, num_gpus, local_check, load_in_8bit, ques_jsons, prompts):
165
+ model_path = os.path.expanduser(model_path)
166
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, truncation_side='left')
167
+
168
+ if not local_check:
169
+ # We assume you have enough GPUs to load one model, but you can modify the gpu_memory_dict to allow CPU offloads
170
+ # We recommend using as few GPUs as possible to load one model for faster inference, e.g. 3 V100 32G or 2 A100 40G for the 33B model
171
+ available_gpu_memory = get_gpu_memory(num_gpus)
172
+ gpu_memory_dict = {i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(num_gpus)}
173
+ gpu_memory_dict["cpu"] = "0GiB"
174
+
175
+ if "t5" in model_path:
176
+ model = AutoModelForSeq2SeqLM.from_pretrained(
177
+ model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
178
+ )
179
+ else:
180
+ model = AutoModelForCausalLM.from_pretrained(
181
+ model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
182
+ )
183
+
184
+ # Initialize with BetterTransformer, injecting Flash-Attention
185
+ model = BetterTransformer.transform(model)
186
+
187
+ # turn on eval mode to disable batch normalization & dropout; works together with torch.inference_mode
188
+ model = model.eval()
189
+ else:
190
+ model = None
191
+
192
+ output_data = []
193
+ for d in ques_jsons:
194
+ new_d = d
195
+
196
+ history = {
197
+ "Recent Dialogs": ["user: Hi!", "bot: Hi! How can I help you today?"],
198
+ "Related Topics": [],
199
+ "Related Summaries": [],
200
+ "Related Dialogs": [],
201
+ "User Input": "",
202
+ }
203
+ memo = {
204
+ "NOTO": [{"summary": "None of the others.", "dialogs": []}]
205
+ }
206
+
207
+ for l_i in range(len(new_d["conversations"])):
208
+ if l_i % 2 == 1:
209
+ bot_thinking = {"retrieval": "", "summarization": ""}
210
+ print("=" * 20 + "start of turn {}".format(l_i // 2 + 1) + "=" * 20)
211
+ user = "user: " + new_d["conversations"][l_i - 1]["value"]
212
+ print(user + "\n\n")
213
+
214
+ # create summary if recent dialogs exceed threshold
215
+ if len(" ### ".join(history["Recent Dialogs"]).split(" ")) > (MaxLen // 2) or len(history["Recent Dialogs"]) >= 10:
216
+ history, memo, bot_thinking = run_summary(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts)
217
+
218
+ # retrieve most related topics for every new user input
219
+ history["User Input"] = user
220
+ if len(memo.keys()) > 1:
221
+ history, bot_thinking = run_retrieval(history, model_path, model, tokenizer, memo, local_check, bot_thinking, prompts)
222
+
223
+ # generate bot response
224
+ system_insturction = prompts["chatting"]["system"]
225
+ task_instruction = prompts["chatting"]["instruction"]
226
+ task_case = "```\nRelated Evidences:\n" + "\n".join(["({}) {}".format(r_tsd_i + 1, {
227
+ "Related Topics": history["Related Topics"][r_tsd_i],
228
+ "Related Summaries": history["Related Summaries"][r_tsd_i],
229
+ "Related Dialogs": history["Related Dialogs"][r_tsd_i]
230
+ }) for r_tsd_i in range(len(history["Related Topics"]))]) + "\n\nRecent Dialogs:\n" + \
231
+ " ### ".join([hrd.replace("\n", " ") for hrd in history["Recent Dialogs"]]) + "\n```\n\nUser Input:\n" + history["User Input"] + " ### bot: "
232
+ qs = q_pre + system_insturction + task_case + task_instruction + qa_link
233
+ outputs = gen_model_output(model_path, model, tokenizer, qs, local_check, "chatting_dialogsum")
234
+ outputs = normalize_chatting_outputs(outputs)
235
+ history["Recent Dialogs"] += [user, "bot: " + outputs]
236
+ print("bot: " + outputs + "\n")
237
+ print("=" * 20 + "end of turn {}".format(l_i // 2 + 1) + "=" * 20)
238
+ print("\n\n\n\n")
239
+ new_d["conversations"][l_i]["thinking"] = json.dumps(bot_thinking)
240
+ new_d["conversations"][l_i]["value"] = outputs
241
+
242
+ output_data.append(d)
243
+ return output_data
244
+
245
+ def run_eval(model_path, num_gpus, local_check, load_in_8bit, question_file, ray_num_gpus, answer_file, prompt_path):
246
+ assert num_gpus % ray_num_gpus == 0
247
+ prompts = json.load(open(prompt_path, "r"))
248
+
249
+ # split question file into num_gpus files
250
+ ques_jsons = json.load(open(os.path.expanduser(question_file), "r"))
251
+
252
+ chunk_size = len(ques_jsons) // (num_gpus // ray_num_gpus)
253
+ ans_handles = []
254
+ for i in range(0, len(ques_jsons), chunk_size):
255
+ get_answers_func = ray.remote(num_gpus=ray_num_gpus)(
256
+ get_model_answers
257
+ ).remote
258
+ ans_handles.append(
259
+ get_answers_func(
260
+ model_path, ray_num_gpus, local_check, load_in_8bit, ques_jsons[i: i + chunk_size], prompts
261
+ )
262
+ )
263
+
264
+ ans_jsons = []
265
+ for ans_handle in ans_handles:
266
+ ans_jsons.extend(ray.get(ans_handle))
267
+
268
+ json.dump(ans_jsons, open(os.path.expanduser(answer_file), "w"), indent=2)
269
+
270
+ if __name__ == "__main__":
271
+ parser = argparse.ArgumentParser()
272
+ parser.add_argument("--model-path", type=str, required=True)
273
+ parser.add_argument("--num-gpus", type=int, required=True)
274
+ parser.add_argument("--ray-num-gpus", type=int, required=True)
275
+ parser.add_argument("--local-check", action="store_true", help="use faker to generate a fake response, used for pipeline prompt checking")
276
+ parser.add_argument("--load-in-8bit", action="store_true")
277
+ parser.add_argument("--question-file", type=str, required=True)
278
+ parser.add_argument("--answer-file", type=str, required=True)
279
+ parser.add_argument("--prompt-path", type=str, required=True)
280
+ args = parser.parse_args()
281
+
282
+ run_eval(
283
+ args.model_path,
284
+ args.num_gpus,
285
+ args.local_check,
286
+ args.load_in_8bit,
287
+ args.question_file,
288
+ args.ray_num_gpus,
289
+ args.answer_file,
290
+ args.prompt_path
291
+ )
baselines/MemoChat/code/codes/eval/get_model_infer_simple.py ADDED
@@ -0,0 +1,150 @@
1
+ import argparse
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
3
+ from optimum.bettertransformer import BetterTransformer
4
+ import torch
5
+ import os
6
+ import json
7
+ from tqdm import tqdm
8
+ import ray
9
+
10
+ q_pre = "<s>\n"
11
+ qa_link = "\n"
12
+ a_pos = "\n</s>"
13
+ MaxLen = 2048
14
+ TarLen = 512
15
+ TaskTarLen = {
16
+ "chatting_dialogsum": MaxLen,
17
+ "chatting_alpacagpt4": MaxLen,
18
+ "writing_topiocqa": TarLen // 2,
19
+ "writing_dialogsum": TarLen,
20
+ "retrieval_dialogsum": 32,
21
+ "retrieval_topiocqa": 32
22
+ }
23
+
24
+ def get_gpu_memory(ray_num_gpus):
25
+ """Get available memory for each GPU."""
26
+ gpu_memory = []
27
+ for gpu_id in range(ray_num_gpus):
28
+ with torch.cuda.device(gpu_id):
29
+ device = torch.cuda.current_device()
30
+ gpu_properties = torch.cuda.get_device_properties(device)
31
+ total_memory = gpu_properties.total_memory / (1024**3)
32
+ allocated_memory = torch.cuda.memory_allocated() / (1024**3)
33
+ available_memory = total_memory - allocated_memory
34
+ gpu_memory.append(available_memory)
35
+ return gpu_memory
36
+
37
+ def run_eval(model_path, model_id, question_file, answer_file, num_gpus, load_in_8bit, ray_num_gpus):
38
+ assert num_gpus % ray_num_gpus == 0
39
+
40
+ # split question file into num_gpus files
41
+ ques_jsons = []
42
+ with open(os.path.expanduser(question_file), "r") as ques_file:
43
+ for line in ques_file:
44
+ ques_jsons.append(line)
45
+
46
+ chunk_size = len(ques_jsons) // (num_gpus // ray_num_gpus)
47
+ ans_handles = []
48
+ for i in range(0, len(ques_jsons), chunk_size):
49
+ get_answers_func = ray.remote(num_gpus=ray_num_gpus)(
50
+ get_model_answers
51
+ ).remote
52
+ ans_handles.append(
53
+ get_answers_func(
54
+ model_path, model_id, ques_jsons[i: i + chunk_size], ray_num_gpus, load_in_8bit
55
+ )
56
+ )
57
+
58
+ ans_jsons = []
59
+ for ans_handle in ans_handles:
60
+ ans_jsons.extend(ray.get(ans_handle))
61
+
62
+ with open(os.path.expanduser(answer_file), "w") as ans_file:
63
+ for line in ans_jsons:
64
+ ans_file.write(json.dumps(line) + "\n")
65
+
66
+ @torch.inference_mode()
67
+ def get_model_answers(model_path, model_id, question_jsons, ray_num_gpus, load_in_8bit):
68
+ model_path = os.path.expanduser(model_path)
69
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, truncation_side='left')
70
+
71
+ available_gpu_memory = get_gpu_memory(ray_num_gpus)
72
+ gpu_memory_dict = {i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(ray_num_gpus)}
73
+ gpu_memory_dict["cpu"] = "0GiB"
74
+
75
+ if "t5" in model_path:
76
+ model = AutoModelForSeq2SeqLM.from_pretrained(
77
+ model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
78
+ )
79
+ else:
80
+ model = AutoModelForCausalLM.from_pretrained(
81
+ model_path, torch_dtype=torch.float16, device_map="auto", max_memory=gpu_memory_dict, load_in_8bit=load_in_8bit
82
+ )
83
+
84
+ # Initialize with BetterTransformer, injecting Flash-Attention
85
+ model = BetterTransformer.transform(model)
86
+
87
+ # turn on eval mode to disable batch normalization & dropout; works together with torch.inference_mode
88
+ model = model.eval()
89
+
90
+ ans_jsons = []
91
+ for i, line in enumerate(tqdm(question_jsons)):
92
+ ques_json = json.loads(line)
93
+ idx = ques_json["question_id"]
94
+ qs = q_pre + ques_json["text"] + qa_link
95
+
96
+ task_type = ques_json["type"]
97
+ if "t5" in model_path:
98
+ input_ids = tokenizer([qs], max_length=MaxLen, truncation=True, add_special_tokens=False).input_ids
99
+ target_len = TaskTarLen[task_type]
100
+ else:
101
+ input_ids = tokenizer([qs], max_length=(MaxLen - TarLen), truncation=True, add_special_tokens=False).input_ids
102
+ target_len = min(len(input_ids[0]) + TaskTarLen[task_type], MaxLen)
103
+
104
+ output_ids = model.generate(
105
+ torch.as_tensor(input_ids).cuda(),
106
+ do_sample=True,
107
+ temperature=0.2,
108
+ max_length=target_len
109
+ )
110
+ if "t5" in model_path:
111
+ output_ids = output_ids[0]
112
+ else:
113
+ output_ids = output_ids[0][len(input_ids[0]):]
114
+ outputs = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
115
+
116
+ print(outputs)
117
+
118
+ ans_jsons.append(
119
+ {
120
+ "question_id": idx,
121
+ "text": outputs,
122
+ "model_id": model_id,
123
+ "metadata": {},
124
+ }
125
+ )
126
+ return ans_jsons
127
+
128
+ if __name__ == "__main__":
129
+ parser = argparse.ArgumentParser()
130
+ parser.add_argument("--model-path", type=str, required=True)
131
+ parser.add_argument("--model-id", type=str, required=True)
132
+ parser.add_argument("--question-file", type=str, required=True)
133
+ parser.add_argument("--answer-file", type=str, default="answer.jsonl")
134
+ parser.add_argument("--num-gpus", type=int, default=1)
135
+ parser.add_argument("--ray-num-gpus", type=int, default=1)
136
+ parser.add_argument("--load-in-8bit", action="store_true")
137
+ args = parser.parse_args()
138
+
139
+ os.environ["RAY_DEDUP_LOGS"] = "0"
140
+ ray.init(num_gpus=args.num_gpus)
141
+
142
+ run_eval(
143
+ args.model_path,
144
+ args.model_id,
145
+ args.question_file,
146
+ args.answer_file,
147
+ args.num_gpus,
148
+ args.load_in_8bit,
149
+ args.ray_num_gpus
150
+ )
baselines/MemoChat/code/codes/train/data_preprocess.py ADDED
@@ -0,0 +1,118 @@
1
+ import logging
2
+ import sys
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ import torch
6
+ import copy
7
+
8
+ import datasets
9
+ from datasets import load_dataset
10
+
11
+ import transformers
12
+ from transformers import (
13
+ HfArgumentParser,
14
+ T5Tokenizer,
15
+ LlamaTokenizer,
16
+ set_seed,
17
+ )
18
+
19
+ q_pre = "<s>\n"
20
+ qa_link = "\n"
21
+ a_pos = "\n</s>"
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ @dataclass
26
+ class ModelArguments:
27
+ model_name_or_path: str = field(
28
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
29
+ )
30
+
31
+ @dataclass
32
+ class DataTrainingArguments:
33
+ data_path: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
34
+ model_max_length: int = field(
35
+ default=2048,
36
+ metadata={
37
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
38
+ },
39
+ )
40
+ preprocessed_path: str = field(
41
+ default=None, metadata={"help": "Path to the preprocessed training data."}
42
+ )
43
+ preprocessing_num_workers: Optional[int] = field(
44
+ default=None,
45
+ metadata={"help": "The number of processes to use for the preprocessing."},
46
+ )
47
+
48
+ def main():
49
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments))
50
+ model_args, data_args = parser.parse_args_into_dataclasses()
51
+
52
+ data_files = {}
53
+ data_files["train"] = data_args.data_path
54
+ raw_datasets = load_dataset(
55
+ "json",
56
+ data_files=data_files
57
+ )
58
+ column_names = raw_datasets["train"].column_names
59
+ print("load dataset finished")
60
+
61
+ if "t5" in model_args.model_name_or_path:
62
+ # use truncation_side='left' to preserve linking between end of prompt and target labels
63
+ tokenizer = T5Tokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
64
+
65
+ def preprocess_function(examples):
66
+ src_inputs = [q_pre + example[0]["value"] + qa_link for example in examples["conversations"]]
67
+ src_model_inputs = tokenizer(src_inputs, max_length=data_args.model_max_length, padding='longest', truncation=True, add_special_tokens=False)
68
+ trg_inputs = [example[1]["value"] + a_pos for example in examples["conversations"]]
69
+ trg_model_inputs = tokenizer(trg_inputs, max_length=data_args.model_max_length, padding='longest', truncation=True, add_special_tokens=False)
70
+ src_model_inputs["labels"] = [
71
+ [(l if l != tokenizer.pad_token_id else label_ignore_id) for l in label] for label in trg_model_inputs["input_ids"]
72
+ ]
73
+ return src_model_inputs
74
+ else:
75
+ # use truncation_side='left' to preserve linking between end of prompt and target labels
76
+ tokenizer = LlamaTokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
77
+
78
+ def preprocess_function(examples):
79
+ inputs = [q_pre + example[0]["value"] + qa_link + example[1]["value"] + a_pos for example in examples["conversations"]]
80
+ model_inputs = tokenizer(inputs, max_length=data_args.model_max_length, padding="longest", truncation=True, add_special_tokens=False)
81
+ model_inputs["labels"] = copy.deepcopy(model_inputs["input_ids"])
82
+ for e_i, example in enumerate(examples["conversations"]):
83
+ source_text = q_pre + example[0]["value"] + qa_link
84
+ target_text = example[1]["value"] + a_pos
85
+ source_ids = tokenizer.encode(source_text, add_special_tokens=False)
86
+ target_ids = tokenizer.encode(target_text, add_special_tokens=False)
87
+ if len(source_ids) >= data_args.model_max_length:
88
+ model_inputs["labels"][e_i] = [label_ignore_id] * data_args.model_max_length
89
+ continue
90
+ else:
91
+ model_inputs["labels"][e_i][:len(source_ids)] = [label_ignore_id] * len(source_ids)
92
+ if len(target_ids) + len(source_ids) >= len(model_inputs["input_ids"][e_i]):
93
+ continue
94
+ else:
95
+ model_inputs["labels"][e_i][(len(target_ids) + len(source_ids)):] = [label_ignore_id] * (len(model_inputs["input_ids"][e_i]) - len(target_ids) - len(source_ids))
96
+ model_inputs["input_ids"] = torch.tensor(model_inputs["input_ids"])
97
+ model_inputs["labels"] = torch.tensor(model_inputs["labels"])
98
+ model_inputs["attention_mask"] = model_inputs["input_ids"].ne(tokenizer.pad_token_id)
99
+ return model_inputs
100
+
101
+ label_ignore_id = -100
102
+
103
+ print("start data preprocess")
104
+ train_dataset = raw_datasets["train"]
105
+ train_dataset = train_dataset.map(
106
+ preprocess_function,
107
+ batched=True,
108
+ batch_size=len(train_dataset),
109
+ remove_columns=column_names,
110
+ num_proc=data_args.preprocessing_num_workers,
111
+ load_from_cache_file=False,
112
+ desc="Running tokenizer on train dataset"
113
+ )
114
+ train_dataset.save_to_disk(data_args.preprocessed_path)
115
+ print("data preprocess finished")
116
+
117
+ if __name__ == "__main__":
118
+ main()
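
The decoder-only branch above supervises only the response span: prompt tokens and trailing padding are overwritten with -100 in `labels`, so the cross-entropy loss ignores them. A toy, self-contained illustration of that masking (token ids are invented):

```python
# Toy illustration of the -100 masking used in preprocess_function above.
# Token ids are invented; only the masking pattern matters.
LABEL_IGNORE_ID = -100

def mask_prompt_labels(source_ids, target_ids, pad_id=0, max_len=12):
    input_ids = (source_ids + target_ids)[:max_len]
    n_pad = max_len - len(input_ids)
    input_ids = input_ids + [pad_id] * n_pad
    labels = list(input_ids)
    labels[:len(source_ids)] = [LABEL_IGNORE_ID] * min(len(source_ids), max_len)  # no loss on the prompt
    if n_pad:
        labels[-n_pad:] = [LABEL_IGNORE_ID] * n_pad                               # no loss on padding
    return input_ids, labels

inp, lab = mask_prompt_labels(source_ids=[11, 12, 13, 14], target_ids=[21, 22, 23])
print(inp)  # [11, 12, 13, 14, 21, 22, 23, 0, 0, 0, 0, 0]
print(lab)  # [-100, -100, -100, -100, 21, 22, 23, -100, -100, -100, -100, -100]
```
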
baselines/MemoChat/code/codes/train/train.py ADDED
@@ -0,0 +1,150 @@
1
+ import logging
2
+ import sys
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ import torch
6
+ import copy
7
+ import json
8
+
9
+ import datasets
10
+ from datasets import load_from_disk
11
+
12
+ import transformers
13
+ from transformers import (
14
+ HfArgumentParser,
15
+ T5ForConditionalGeneration,
16
+ T5Tokenizer,
17
+ T5Config,
18
+ LlamaForCausalLM,
19
+ LlamaTokenizer,
20
+ LlamaConfig,
21
+ Trainer,
22
+ TrainingArguments,
23
+ set_seed,
24
+ )
25
+
26
+ from optimum.bettertransformer import BetterTransformer
27
+
28
+ q_pre = "<s>\n"
29
+ qa_link = "\n"
30
+ a_pos = "\n</s>"
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ @dataclass
35
+ class ModelArguments:
36
+ model_name_or_path: str = field(
37
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
38
+ )
39
+
40
+ @dataclass
41
+ class DataTrainingArguments:
42
+ model_max_length: int = field(
43
+ default=2048,
44
+ metadata={
45
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
46
+ },
47
+ )
48
+ max_train_samples: Optional[int] = field(
49
+ default=None,
50
+ metadata={
51
+ "help": (
52
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
53
+ "value if set."
54
+ )
55
+ },
56
+ )
57
+ preprocessed_path: str = field(
58
+ default=None, metadata={"help": "Path to the preprocessed training data."}
59
+ )
60
+
61
+ def main():
62
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
63
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
64
+
65
+ # Setup logging
66
+ logging.basicConfig(
67
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
68
+ datefmt="%m/%d/%Y %H:%M:%S",
69
+ handlers=[logging.StreamHandler(sys.stdout)],
70
+ )
71
+
72
+ if training_args.should_log:
73
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
74
+ transformers.utils.logging.set_verbosity_info()
75
+
76
+ log_level = training_args.get_process_log_level()
77
+ logger.setLevel(log_level)
78
+ datasets.utils.logging.set_verbosity(log_level)
79
+ transformers.utils.logging.set_verbosity(log_level)
80
+ transformers.utils.logging.enable_default_handler()
81
+ transformers.utils.logging.enable_explicit_format()
82
+ # Log on each process the small summary:
83
+ logger.warning(
84
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
85
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
86
+ )
87
+ logger.info(f"Training/evaluation parameters {training_args}")
88
+
89
+ logger.info("start to load dataset")
90
+ train_dataset = load_from_disk(data_args.preprocessed_path)
91
+ column_names = train_dataset.column_names
92
+ if data_args.max_train_samples is not None:
93
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
94
+ train_dataset = train_dataset.select(range(max_train_samples))
95
+ logger.info("load dataset finished")
96
+
97
+ if "t5" in model_args.model_name_or_path:
98
+ # load config and tokenziers
99
+ config = T5Config.from_pretrained(model_args.model_name_or_path)
100
+ config.use_cache=False
101
+ # use truncation_side='left' to preserve linking between end of prompt and target labels
102
+ tokenizer = T5Tokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
103
+ model = T5ForConditionalGeneration.from_pretrained(model_args.model_name_or_path, config=config)
104
+ else:
105
+ # load config and tokenziers
106
+ config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
107
+ config.use_cache=False
108
+ # use truncation_side='left' to preserve linking between end of prompt and target labels
109
+ tokenizer = LlamaTokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')
110
+ # initialize modules
111
+ model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
112
+
113
+ # convert normal model to bettertransformer
114
+ model = BetterTransformer.transform(model)
115
+
116
+ # Setup seed
117
+ set_seed(training_args.seed)
118
+ if len(tokenizer) > tokenizer.vocab_size:
119
+ model.resize_token_embeddings(len(tokenizer))
120
+
121
+ # Setup Trainer
122
+ trainer = Trainer(
123
+ model=model,
124
+ args=training_args,
125
+ train_dataset=train_dataset,
126
+ tokenizer=tokenizer
127
+ )
128
+
129
+ # Training
130
+ train_result = trainer.train()
131
+
132
+ # convert bettertransformer to normal model
133
+ trainer.model = BetterTransformer.reverse(trainer.model)
134
+ trainer.save_state()
135
+
136
+ # save fp16 model under deepspeed zero2 or zero3
137
+ c_stage = json.load(open(training_args.deepspeed, "r"))["zero_optimization"]["stage"]
138
+ if c_stage in [2, 3]:
139
+ if c_stage == 2:
140
+ w_state_dict = trainer.model.state_dict()
141
+ else:
142
+ w_state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
143
+ if trainer.is_world_process_zero():
144
+ state_dict = {key: value.half().cpu() for key, value in w_state_dict.items()}
145
+ trainer._save(training_args.output_dir, state_dict=state_dict)
146
+ else:
147
+ trainer.save_model()
148
+
149
+ if __name__ == "__main__":
150
+ main()
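
The final save path above depends on the ZeRO stage declared in the DeepSpeed config: stage 2 can take the state dict from the unwrapped model, stage 3 has to gather a consolidated 16-bit state dict from the parameter shards, and anything else falls back to `trainer.save_model()`. A minimal, self-contained sketch of that branching; the inline config is a stand-in for the `ds_config_*.json` files that follow:

```python
# Sketch of the stage-dependent checkpointing decision in main() above.
# The inline config is a stand-in for the ds_config_*.json files below.
import json

ds_config = json.loads('{"zero_optimization": {"stage": 3}}')
stage = ds_config["zero_optimization"]["stage"]

if stage == 2:
    print("stage 2: read trainer.model.state_dict() and save it as fp16")
elif stage == 3:
    print("stage 3: consolidate a 16-bit state dict from the ZeRO-3 shards, then save")
else:
    print(f"stage {stage}: plain trainer.save_model() is sufficient")
```
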
baselines/MemoChat/code/configs/ds_config_13b.json ADDED
@@ -0,0 +1,53 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupDecayLR",
24
+ "params": {
25
+ "total_num_steps" : "auto",
26
+ "warmup_min_lr": "auto",
27
+ "warmup_max_lr": "auto",
28
+ "warmup_num_steps": "auto"
29
+ }
30
+ },
31
+ "zero_optimization": {
32
+ "stage": 3,
33
+ "offload_optimizer": {
34
+ "device": "cpu",
35
+ "pin_memory": true
36
+ },
37
+ "overlap_comm": true,
38
+ "contiguous_gradients": true,
39
+ "sub_group_size": 1e9,
40
+ "reduce_bucket_size": "auto",
41
+ "stage3_prefetch_bucket_size": "auto",
42
+ "stage3_param_persistence_threshold": "auto",
43
+ "stage3_max_live_parameters": 1e9,
44
+ "stage3_max_reuse_distance": 1e9,
45
+ "stage3_gather_16bit_weights_on_model_save": true
46
+ },
47
+ "gradient_accumulation_steps": "auto",
48
+ "gradient_clipping": "auto",
49
+ "steps_per_print": 2000,
50
+ "train_batch_size": "auto",
51
+ "train_micro_batch_size_per_gpu": "auto",
52
+ "wall_clock_breakdown": false
53
+ }
baselines/MemoChat/code/configs/ds_config_33b.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupDecayLR",
24
+ "params": {
25
+ "total_num_steps" : "auto",
26
+ "warmup_min_lr": "auto",
27
+ "warmup_max_lr": "auto",
28
+ "warmup_num_steps": "auto"
29
+ }
30
+ },
31
+ "zero_optimization": {
32
+ "stage": 3,
33
+ "offload_optimizer": {
34
+ "device": "cpu",
35
+ "pin_memory": true
36
+ },
37
+ "offload_param": {
38
+ "device": "cpu",
39
+ "pin_memory": true
40
+ },
41
+ "overlap_comm": true,
42
+ "contiguous_gradients": true,
43
+ "sub_group_size": 1e9,
44
+ "reduce_bucket_size": "auto",
45
+ "stage3_prefetch_bucket_size": "auto",
46
+ "stage3_param_persistence_threshold": "auto",
47
+ "stage3_max_live_parameters": 1e9,
48
+ "stage3_max_reuse_distance": 1e9,
49
+ "stage3_gather_16bit_weights_on_model_save": true
50
+ },
51
+ "gradient_accumulation_steps": "auto",
52
+ "gradient_clipping": "auto",
53
+ "steps_per_print": 2000,
54
+ "train_batch_size": "auto",
55
+ "train_micro_batch_size_per_gpu": "auto",
56
+ "wall_clock_breakdown": false
57
+ }
baselines/MemoChat/code/configs/ds_config_3b.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupLR",
24
+ "params": {
25
+ "warmup_min_lr": "auto",
26
+ "warmup_max_lr": "auto",
27
+ "warmup_num_steps": "auto"
28
+ }
29
+ },
30
+ "zero_optimization": {
31
+ "stage": 1
32
+ },
33
+ "gradient_accumulation_steps": "auto",
34
+ "gradient_clipping": "auto",
35
+ "steps_per_print": 2000,
36
+ "train_batch_size": "auto",
37
+ "train_micro_batch_size_per_gpu": "auto",
38
+ "wall_clock_breakdown": false
39
+ }
baselines/MemoChat/code/configs/ds_config_7b.json ADDED
@@ -0,0 +1,49 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupLR",
24
+ "params": {
25
+ "warmup_min_lr": "auto",
26
+ "warmup_max_lr": "auto",
27
+ "warmup_num_steps": "auto"
28
+ }
29
+ },
30
+ "zero_optimization": {
31
+ "stage": 2,
32
+ "offload_optimizer": {
33
+ "device": "cpu",
34
+ "pin_memory": true
35
+ },
36
+ "allgather_partitions": true,
37
+ "allgather_bucket_size": 2e8,
38
+ "overlap_comm": true,
39
+ "reduce_scatter": true,
40
+ "reduce_bucket_size": 2e8,
41
+ "contiguous_gradients": true
42
+ },
43
+ "gradient_accumulation_steps": "auto",
44
+ "gradient_clipping": "auto",
45
+ "steps_per_print": 2000,
46
+ "train_batch_size": "auto",
47
+ "train_micro_batch_size_per_gpu": "auto",
48
+ "wall_clock_breakdown": false
49
+ }
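
Across the four configs, the memory-saving features scale with model size: the 3B config stays at ZeRO stage 1, the 7B config moves to stage 2 with CPU optimizer offload, the 13B config uses stage 3 with optimizer offload, and the 33B config uses stage 3 with both optimizer and parameter offload. A small sketch that summarizes what a given config enables; the two inline configs copy only the relevant fields from the 3B and 33B files:

```python
# Summarize the ZeRO features a DeepSpeed config enables. The inline configs
# copy only the relevant fields from ds_config_3b.json / ds_config_33b.json.
def zero_summary(cfg: dict) -> str:
    zo = cfg.get("zero_optimization", {})
    parts = [f"stage {zo.get('stage', 0)}"]
    if "offload_optimizer" in zo:
        parts.append(f"optimizer offload -> {zo['offload_optimizer']['device']}")
    if "offload_param" in zo:
        parts.append(f"param offload -> {zo['offload_param']['device']}")
    return ", ".join(parts)

cfg_3b = {"zero_optimization": {"stage": 1}}
cfg_33b = {"zero_optimization": {"stage": 3,
                                 "offload_optimizer": {"device": "cpu"},
                                 "offload_param": {"device": "cpu"}}}
print(zero_summary(cfg_3b))   # stage 1
print(zero_summary(cfg_33b))  # stage 3, optimizer offload -> cpu, param offload -> cpu
```
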
baselines/MemoChat/code/scripts/llm_judge.sh ADDED
@@ -0,0 +1,35 @@
1
+ export GLOO_SOCKET_IFNAME=eth0
2
+ export WANDB_MODE=disabled
3
+
4
+ maindir=$1
5
+ datadir=${maindir}data
6
+ codedir=${maindir}code
7
+
8
+ settings=("10k")
9
+ models=("t5-3b" "vicuna-7b" "vicuna-13b" "vicuna-33b")
10
+
11
+ for model in "${models[@]}"
12
+ do
13
+ for setting in "${settings[@]}"
14
+ do
15
+ python3 ${codedir}/codes/api/llm_judge.py \
16
+ ${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_${model}_${setting}.json \
17
+ gpt-4 \
18
+ YourOpenAIKey \
19
+ ${datadir}/llm_judge/llm_judge_gpt-4_${model}_${setting}.json \
20
+ ${datadir}/prompts.json
21
+ done
22
+ done
23
+
24
+ gpt_settings=("2k" "memochat")
25
+
26
+ for gpt_setting in "${gpt_settings[@]}"
27
+ do
28
+ python3 ${codedir}/codes/api/llm_judge.py \
29
+ ${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_gpt-3.5-turbo-${gpt_setting}.json \
30
+ gpt-4 \
31
+ YourOpenAIKey \
32
+ ${datadir}/llm_judge/llm_judge_gpt-4_gpt-3.5-turbo-${gpt_setting}.json \
33
+ ${datadir}/prompts.json
34
+ done
35
+
baselines/MemoChat/code/scripts/memochat.sh ADDED
@@ -0,0 +1,34 @@
1
+ export GLOO_SOCKET_IFNAME=eth0
2
+ export WANDB_MODE=disabled
3
+
4
+ maindir=$1
5
+ datadir=${maindir}data
6
+ codedir=${maindir}code
7
+
8
+ test_data=${datadir}/mtbenchplus/mtbenchplus.json
9
+
10
+ settings=("1k", "10k")
11
+ models=("t5-3b" "vicuna-7b" "vicuna-13b" "vicuna-33b")
12
+
13
+ for model in "${models[@]}"
14
+ do
15
+ for setting in "${settings[@]}"
16
+ do
17
+ finetuned_model_path=${maindir}model/${model}_${setting}/
18
+ case ${model} in
19
+ "vicuna-33b")
20
+ RAYGPUS=2
21
+ ;;
22
+ "t5-3b"|"vicuna-7b"|"vicuna-13b")
23
+ RAYGPUS=1
24
+ ;;
25
+ esac
26
+ python3 ${codedir}/codes/eval/get_model_infer_memochat.py \
27
+ --model-path ${finetuned_model_path} \
28
+ --question-file ${test_data} \
29
+ --answer-file ${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_${model}_${setting}.json \
30
+ --num-gpus $GPU_NUM_PER_NODE \
31
+ --ray-num-gpus ${RAYGPUS} \
32
+ --prompt-path ${datadir}/prompts.json
33
+ done
34
+ done
baselines/MemoChat/code/scripts/memochat_gpt.sh ADDED
@@ -0,0 +1,18 @@
1
+ export GLOO_SOCKET_IFNAME=eth0
2
+ export WANDB_MODE=disabled
3
+
4
+ maindir=$1
5
+ datadir=${maindir}data
6
+ codedir=${maindir}code
7
+
8
+ gpt_settings=("2k" "memochat")
9
+
10
+ for gpt_setting in "${gpt_settings[@]}"
11
+ do
12
+ python3 ${codedir}/codes/api/gpt_${gpt_setting}.py \
13
+ ${datadir}/mtbenchplus/mtbenchplus.json \
14
+ gpt-3.5-turbo \
15
+ YourOpenAIKey \
16
+ ${datadir}/mtbenchplus/mtbenchplus_testing/mtbenchplus_testing_gpt-3.5-turbo-${gpt_setting}.json \
17
+ ${datadir}/prompts.json
18
+ done
baselines/MemoChat/code/scripts/tuning.sh ADDED
@@ -0,0 +1,110 @@
1
+ export GLOO_SOCKET_IFNAME=eth0
2
+ export WANDB_MODE=disabled
3
+
4
+ maindir=$1
5
+ datadir=${maindir}data
6
+ codedir=${maindir}code
7
+
8
+ MAXLEN=2048
9
+ EPOCH=3
10
+ test_data=${datadir}/memochat_instructions/test.jsonl
11
+
12
+ settings=("1k" "10k")
13
+ models=("t5-3b" "vicuna-7b" "vicuna-13b" "vicuna-33b")
14
+
15
+ for model in "${models[@]}"
16
+ do
17
+
18
+ raw_model_path=${maindir}model/fastchat-${model}/
19
+ case ${model} in
20
+ "vicuna-33b")
21
+ RAYGPUS=2
22
+ ;;
23
+ "t5-3b"|"vicuna-7b"|"vicuna-13b")
24
+ RAYGPUS=1
25
+ ;;
26
+ esac
27
+
28
+ # zeroshot inference on one node
29
+ python3 ${codedir}/codes/eval/get_model_infer_simple.py \
30
+ --model-id ${model}_zeroshot \
31
+ --model-path ${raw_model_path} \
32
+ --question-file ${test_data} \
33
+ --answer-file ${datadir}/instruction_testing/instruction_testing_${model}_zeroshot.jsonl \
34
+ --num-gpus $GPU_NUM_PER_NODE \
35
+ --ray-num-gpus ${RAYGPUS}
36
+
37
+ # tuning
38
+ for setting in "${settings[@]}"
39
+ do
40
+ data_path=${datadir}/memochat_instructions/train_${setting}.json
41
+ preprocessed_data_dir=${datadir}/memochat_instructions/processed_${setting}_${model%-*}.pt
42
+ model_output_path=${maindir}model/${model}_${setting}/
43
+ deepspeed_config_path=${codedir}/configs/ds_config_${model#*-}.json
44
+
45
+ case ${model} in
46
+ "t5-3b")
47
+ PER_GPU_BATCH=8
48
+ GRA_ACC=2
49
+ ;;
50
+ "vicuna-7b")
51
+ PER_GPU_BATCH=16
52
+ GRA_ACC=1
53
+ ;;
54
+ "vicuna-13b")
55
+ PER_GPU_BATCH=8
56
+ GRA_ACC=2
57
+ ;;
58
+ "vicuna-33b")
59
+ PER_GPU_BATCH=4
60
+ GRA_ACC=4
61
+ ;;
62
+ esac
63
+
64
+ # train data preprocess
65
+ python3 ${codedir}/codes/train/data_preprocess.py \
66
+ --model_name_or_path ${raw_model_path} \
67
+ --data_path ${data_path} \
68
+ --preprocessing_num_workers=1 \
69
+ --model_max_length ${MAXLEN} \
70
+ --preprocessed_path ${preprocessed_data_dir}
71
+
72
+ # training: avaliable for multi nodes
73
+ torchrun --nnodes=$NODE_NUM \
74
+ --node_rank=$INDEX \
75
+ --nproc_per_node $GPU_NUM_PER_NODE \
76
+ --master_addr $MASTER_ADDR \
77
+ --master_port $MASTER_PORT \
78
+ ${codedir}/codes/train/train.py \
79
+ --model_name_or_path ${raw_model_path} \
80
+ --bf16 True \
81
+ --output_dir ${model_output_path} \
82
+ --num_train_epochs ${EPOCH} \
83
+ --per_device_train_batch_size ${PER_GPU_BATCH} \
84
+ --gradient_accumulation_steps ${GRA_ACC} \
85
+ --save_strategy "steps" \
86
+ --save_steps 1500 \
87
+ --save_total_limit 1 \
88
+ --learning_rate 2e-5 \
89
+ --log_level "info" \
90
+ --logging_strategy "steps" \
91
+ --logging_steps 1 \
92
+ --weight_decay 0. \
93
+ --warmup_ratio 0.04 \
94
+ --lr_scheduler_type "cosine" \
95
+ --deepspeed ${deepspeed_config_path} \
96
+ --tf32 True \
97
+ --model_max_length ${MAXLEN} \
98
+ --preprocessed_path ${preprocessed_data_dir} \
99
+ --gradient_checkpointing True
100
+
101
+ # tuning inference
102
+ python3 ${codedir}/codes/eval/get_model_infer_simple.py \
103
+ --model-id ${model}_${setting} \
104
+ --model-path ${model_output_path} \
105
+ --question-file ${test_data} \
106
+ --answer-file ${datadir}/instruction_testing/instruction_testing_${model}_${setting}.jsonl \
107
+ --num-gpus $GPU_NUM_PER_NODE \
108
+ --ray-num-gpus ${RAYGPUS}
109
+ done
110
+ done
baselines/MemoChat/core_requirement.txt ADDED
@@ -0,0 +1,13 @@
1
+ accelerate==0.19.0
2
+ datasets==2.10.1
3
+ deepspeed==0.9.4
4
+ evaluate==0.4.0
5
+ Faker==18.11.2
6
+ openai==0.27.2
7
+ optimum==1.9.1
8
+ ray==2.5.1
9
+ tiktoken==0.4.0
10
+ tokenizers==0.13.2
11
+ torch==2.0.1
12
+ torchtext==0.15.2
13
+ transformers==4.29.2
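
Several of these pins interact (optimum/transformers for BetterTransformer, deepspeed/torch for ZeRO), so checking that the active environment matches them before a long run can save time. A hedged sketch that only reads installed package metadata and assumes nothing beyond the `name==version` format of the file above:

```python
# Compare installed package versions against a requirements file with "==" pins.
from importlib.metadata import version, PackageNotFoundError

def check_pins(path="core_requirement.txt"):
    for line in open(path):
        line = line.strip()
        if not line or "==" not in line:
            continue
        name, pinned = line.split("==")
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = "not installed"
        flag = "OK" if installed == pinned else "MISMATCH"
        print(f"{name:15s} pinned {pinned:10s} installed {installed:12s} {flag}")

check_pins()
```
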
baselines/MemoChat/run_memochat_baseline.py ADDED
@@ -0,0 +1,634 @@
1
+ """
2
+ MemoChat baseline for the EvolV-Mem benchmark.
3
+
4
+ Adapts MemoChat's three-stage pipeline (memo writing, retrieval, chatting)
5
+ to the EvolV-Mem benchmark using Qwen-30B via vLLM.
6
+
7
+ Pipeline per question:
8
+ 1. Memo writing: extract {topic, summary} from each haystack session (cached)
9
+ 2. Embedding pre-filter: SBert selects top-50 memos by similarity
10
+ 3. MemoChat retrieval: LLM selects final relevant topics from top-50
11
+ 4. Answer generation: LLM generates answer from retrieved memos
12
+
13
+ Usage:
14
+ python baselines/MemoChat/run_memochat_baseline.py \
15
+ --in_file dataset/evolv_mem_v4.json \
16
+ --out_file output/memochat_qwen30b_v4.jsonl \
17
+ --sessions_file dataset/all_sessions.json \
18
+ --profile_file metadata/generated_user_profile.json
19
+
20
+ Env vars:
21
+ VLLM_BASE_URL (default http://localhost:8000/v1)
22
+ VLLM_API_KEY (default EMPTY)
23
+ """
24
+
25
+ import argparse
26
+ import json
27
+ import logging
28
+ import os
29
+ import re
30
+ import sys
31
+ import time
32
+ from collections import defaultdict
33
+ from typing import Dict, List, Optional, Tuple
34
+
35
+ import numpy as np
36
+ from tqdm import tqdm
37
+
38
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Load MemoChat prompts
42
+ # ---------------------------------------------------------------------------
43
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
44
+ PROMPTS_PATH = os.path.join(SCRIPT_DIR, "data", "prompts.json")
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # vLLM LLM helper
49
+ # ---------------------------------------------------------------------------
50
+
51
+ def get_llm_client():
52
+ from openai import OpenAI
53
+ return OpenAI(
54
+ base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
55
+ api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
56
+ )
57
+
58
+
59
+ MODEL_NAME = os.getenv("VLLM_MODEL_NAME", "Qwen/Qwen3-30B-A3B-Instruct-2507")
60
+
61
+
62
+ def llm_call(client, prompt: str, max_tokens: int = 4096, temperature: float = 0.2) -> str:
63
+ """Call the vLLM server with retry logic."""
64
+ for attempt in range(6):
65
+ try:
66
+ response = client.chat.completions.create(
67
+ model=MODEL_NAME,
68
+ messages=[{"role": "user", "content": prompt}],
69
+ max_tokens=max_tokens,
70
+ temperature=temperature,
71
+ )
72
+ content = response.choices[0].message.content if response.choices else None
73
+ if content is None:
74
+ wait = min(2 ** attempt * 2, 30)
75
+ print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
76
+ time.sleep(wait)
77
+ continue
78
+ return content.strip()
79
+ except Exception as e:
80
+ msg = str(e).lower()
81
+ if any(code in msg for code in ("429", "500", "503", "rate limit")):
82
+ wait = min(2 ** attempt * 5, 60)
83
+ print(f"[WARN] LLM retry {attempt+1}/6, sleeping {wait}s: {e}")
84
+ time.sleep(wait)
85
+ continue
86
+ print(f"[ERROR] LLM call failed: {e}")
87
+ raise
88
+ raise RuntimeError("LLM call failed after 6 retries")
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # MemoChat output parsers (ported from get_model_infer_memochat.py)
93
+ # ---------------------------------------------------------------------------
94
+
95
+ def normalize_model_outputs(model_text: str) -> List[Dict]:
96
+ """Parse memo writing output into structured topic-summary dicts."""
97
+ extracted_elements = [
98
+ re.sub(r'\s+', ' ', mt.replace('"', '').replace("'", ""))
99
+ for mt in re.findall(r"'[^']*'|\"[^\"]*\"|\d+", model_text)
100
+ ]
101
+ model_outputs = []
102
+ ti = 0
103
+ while ti + 7 < len(extracted_elements):
104
+ if (extracted_elements[ti] == "topic"
105
+ and extracted_elements[ti + 2] == "summary"
106
+ and extracted_elements[ti + 4] == "start"
107
+ and extracted_elements[ti + 6] == "end"):
108
+ try:
109
+ model_outputs.append({
110
+ "topic": extracted_elements[ti + 1],
111
+ "summary": extracted_elements[ti + 3],
112
+ "start": int(extracted_elements[ti + 5]),
113
+ "end": int(extracted_elements[ti + 7]),
114
+ })
115
+ except (ValueError, IndexError):
116
+ pass
117
+ ti += 1
118
+ return model_outputs
119
+
120
+
121
+ def normalize_chatting_outputs(model_outputs: str) -> str:
122
+ """Clean up chatting response whitespace."""
123
+ lines = model_outputs.split("\n")
124
+ result = [' '.join(line.split()) for line in lines]
125
+ return '\n'.join(result)
126
+
127
+
128
+ # ---------------------------------------------------------------------------
129
+ # Stage 1: Memo Writing (per session, cached)
130
+ # ---------------------------------------------------------------------------
131
+
132
+ def format_session_for_writing(session_turns: List[Dict]) -> str:
133
+ """Format a session's turns as numbered lines for MemoChat's writing prompt."""
134
+ lines = []
135
+ for i, turn in enumerate(session_turns):
136
+ role = turn.get("role", "user")
137
+ content = turn.get("content", "").replace("\n", " ")
138
+ lines.append(f"(line {i + 1}) {role}: {content}")
139
+ return "\n".join(lines)
140
+
141
+
142
+ def write_memos_for_session(
143
+ session_id: str,
144
+ session_turns: List[Dict],
145
+ client,
146
+ prompts: Dict,
147
+ memo_cache_dir: str,
148
+ ) -> List[Dict]:
149
+ """Extract topic-summary memos from a single session using MemoChat's writing prompt.
150
+
151
+ Returns list of {topic, summary, start, end, session_id}.
152
+ Results are cached to disk.
153
+ """
154
+ cache_path = os.path.join(memo_cache_dir, f"{session_id}.json")
155
+ if os.path.exists(cache_path):
156
+ with open(cache_path) as f:
157
+ return json.load(f)
158
+
159
+ if not session_turns:
160
+ memos = []
161
+ with open(cache_path, "w") as f:
162
+ json.dump(memos, f)
163
+ return memos
164
+
165
+ # Build MemoChat writing prompt
166
+ num_lines = len(session_turns)
167
+ system_instruction = prompts["writing_dialogsum"]["system"]
168
+ task_instruction = prompts["writing_dialogsum"]["instruction"]
169
+
170
+ history_log = "\n\n```\nTask Conversation:\n" + format_session_for_writing(session_turns)
171
+ prompt = (
172
+ system_instruction.replace("LINE", str(num_lines))
173
+ + history_log
174
+ + "\n```"
175
+ + task_instruction.replace("LINE", str(num_lines))
176
+ )
177
+
178
+ # Call LLM
179
+ output = llm_call(client, prompt, max_tokens=512, temperature=0.2)
180
+ memos_raw = normalize_model_outputs(output)
181
+
182
+ # Attach session_id to each memo
183
+ memos = []
184
+ for m in memos_raw:
185
+ memos.append({
186
+ "topic": m["topic"],
187
+ "summary": m["summary"],
188
+ "start": m["start"],
189
+ "end": m["end"],
190
+ "session_id": session_id,
191
+ })
192
+
193
+ # If no memos extracted, create a fallback from the session content
194
+ if not memos:
195
+ # Use first and last turn as a basic summary
196
+ first_content = session_turns[0].get("content", "")[:200] if session_turns else ""
197
+ memos.append({
198
+ "topic": f"session_{session_id}",
199
+ "summary": f"Conversation about: {first_content}...",
200
+ "start": 1,
201
+ "end": num_lines,
202
+ "session_id": session_id,
203
+ })
204
+
205
+ # Cache
206
+ os.makedirs(os.path.dirname(cache_path), exist_ok=True)
207
+ with open(cache_path, "w") as f:
208
+ json.dump(memos, f)
209
+
210
+ return memos
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # Stage 1 (fast): Build memos from pre-computed session summaries
215
+ # ---------------------------------------------------------------------------
216
+
217
+ def build_memos_from_summaries(
218
+ haystack_session_ids: List[str],
219
+ haystack_dates: List[str],
220
+ summaries: Dict,
221
+ ) -> List[Dict]:
222
+ """Build memo entries directly from all_session_summary.json — no LLM calls."""
223
+ memos = []
224
+ for sid, date_str in zip(haystack_session_ids, haystack_dates):
225
+ summary_data = summaries.get(sid)
226
+ if summary_data is None:
227
+ continue
228
+ text = summary_data.get("session_summary", "")
229
+ if not text:
230
+ turn_sums = summary_data.get("turn_summaries", [])
231
+ if turn_sums:
232
+ text = " ".join(turn_sums)
233
+ else:
234
+ continue
235
+ # Use first ~60 chars as topic, full text as summary
236
+ topic = text[:60].rstrip(". ") if len(text) > 60 else text
237
+ memos.append({
238
+ "topic": topic,
239
+ "summary": text,
240
+ "session_id": sid,
241
+ "session_date": date_str,
242
+ })
243
+ return memos
244
+
245
+
246
+ # ---------------------------------------------------------------------------
247
+ # Stage 2: Embedding pre-filter
248
+ # ---------------------------------------------------------------------------
249
+
250
+ def embed_and_filter(
251
+ question: str,
252
+ all_memos: List[Dict],
253
+ embedding_model,
254
+ top_k: int = 50,
255
+ ) -> List[Dict]:
256
+ """Use SBert to select top-k most relevant memos by cosine similarity."""
257
+ if len(all_memos) <= top_k:
258
+ return all_memos
259
+
260
+ # Build texts to embed
261
+ memo_texts = [f"{m['topic']}. {m['summary']}" for m in all_memos]
262
+
263
+ # Encode
264
+ question_emb = embedding_model.encode(question)
265
+ memo_embs = embedding_model.encode(memo_texts)
266
+
267
+ # Cosine similarity
268
+ question_norm = question_emb / (np.linalg.norm(question_emb) + 1e-10)
269
+ memo_norms = memo_embs / (np.linalg.norm(memo_embs, axis=1, keepdims=True) + 1e-10)
270
+ similarities = memo_norms @ question_norm
271
+
272
+ # Top-k indices
273
+ top_indices = np.argsort(similarities)[::-1][:top_k]
274
+ return [all_memos[i] for i in top_indices]
275
+
276
+
277
+ # ---------------------------------------------------------------------------
278
+ # Stage 3: MemoChat LLM-based retrieval
279
+ # ---------------------------------------------------------------------------
280
+
281
+ def memochat_retrieve(
282
+ question: str,
283
+ candidate_memos: List[Dict],
284
+ client,
285
+ prompts: Dict,
286
+ ) -> List[Dict]:
287
+ """Apply MemoChat's retrieval prompt to select relevant memos from candidates."""
288
+ if not candidate_memos:
289
+ return []
290
+
291
+ # Build topic options list
292
+ topic_options = []
293
+ for i, m in enumerate(candidate_memos):
294
+ topic_options.append(f"({i + 1}) {m['topic']}. {m['summary']}")
295
+
296
+ # Add NOTO option
297
+ noto_idx = len(candidate_memos) + 1
298
+ topic_options.append(f"({noto_idx}) NOTO. None of the others.")
299
+
300
+ system_instruction = prompts["retrieval"]["system"]
301
+ task_instruction = prompts["retrieval"]["instruction"]
302
+
303
+ task_case = (
304
+ "```\nQuery Sentence:\n" + question
305
+ + "\nTopic Options:\n" + "\n".join(topic_options)
306
+ + "\n```"
307
+ )
308
+
309
+ prompt = (
310
+ system_instruction.replace("OPTION", str(len(candidate_memos) + 1))
311
+ + task_case
312
+ + task_instruction.replace("OPTION", str(len(candidate_memos) + 1))
313
+ )
314
+
315
+ output = llm_call(client, prompt, max_tokens=2048, temperature=0.2)
316
+
317
+ # Parse selected indices
318
+ selected_memos = []
319
+ for part in output.split("#"):
320
+ part = part.strip()
321
+ try:
322
+ idx = int(part) - 1
323
+ if 0 <= idx < len(candidate_memos):
324
+ selected_memos.append(candidate_memos[idx])
325
+ except ValueError:
326
+ continue
327
+
328
+ return selected_memos
329
+
330
+
331
+ # ---------------------------------------------------------------------------
332
+ # Stage 4: Answer generation with MemoChat chatting prompt
333
+ # ---------------------------------------------------------------------------
334
+
335
+ def generate_answer(
336
+ question: str,
337
+ question_date: str,
338
+ retrieved_memos: List[Dict],
339
+ user_profile: Optional[str],
340
+ client,
341
+ prompts: Dict,
342
+ ) -> str:
343
+ """Generate an answer using MemoChat's chatting prompt format."""
344
+ system_instruction = prompts["chatting"]["system"]
345
+
346
+ # Build "Related Evidences" section from retrieved memos
347
+ evidence_lines = []
348
+ for i, m in enumerate(retrieved_memos):
349
+ evidence_lines.append(
350
+ f"({i + 1}) {{'Related Topics': '{m['topic']}', "
351
+ f"'Related Summaries': '{m['summary']}'}}"
352
+ )
353
+ evidences_str = "\n".join(evidence_lines) if evidence_lines else "(No related evidences found)"
354
+
355
+ # Build the chatting prompt
356
+ profile_section = ""
357
+ if user_profile:
358
+ profile_section = f"\nUser Profile:\n{user_profile}\n"
359
+
360
+ task_case = (
361
+ f"```\nRelated Evidences:\n{evidences_str}"
362
+ f"\n\nRecent Dialogs:\n(no recent dialogs)"
363
+ f"\n```"
364
+ f"{profile_section}"
365
+ f"\n\nCurrent Date: {question_date}"
366
+ f"\n\nUser Input:\nuser: {question} ### bot: "
367
+ )
368
+
369
+ prompt = system_instruction + task_case
370
+
371
+ output = llm_call(client, prompt, max_tokens=8192, temperature=0.2)
372
+ return normalize_chatting_outputs(output)
373
+
374
+
375
+ # ---------------------------------------------------------------------------
376
+ # Main
377
+ # ---------------------------------------------------------------------------
378
+
379
+ # ---------------------------------------------------------------------------
380
+ # Retrieval metrics
381
+ # ---------------------------------------------------------------------------
382
+
383
+ def evaluate_retrieval(recalled_docs, correct_docs):
384
+ recall_any = float(any(doc in recalled_docs for doc in correct_docs))
385
+ recall_all = float(all(doc in recalled_docs for doc in correct_docs))
386
+ return recall_any, recall_all
387
+
388
+
389
+ def print_average_metrics(retrieval_metric_list):
390
+ metric_sums = defaultdict(float)
391
+ metric_counts = defaultdict(int)
392
+ for metric in retrieval_metric_list:
393
+ for k, v in metric.items():
394
+ metric_sums[k] += v
395
+ metric_counts[k] += 1
396
+ print(" Average retrieval metrics:")
397
+ for k in sorted(metric_sums):
398
+ avg = metric_sums[k] / metric_counts[k]
399
+ print(f" {k}: {avg:.4f}")
400
+
401
+
402
+ # ---------------------------------------------------------------------------
403
+ # Main
404
+ # ---------------------------------------------------------------------------
405
+
406
+ def main():
407
+ parser = argparse.ArgumentParser(description="MemoChat baseline for EvolV-Mem")
408
+ parser.add_argument("--in_file", type=str, required=True,
409
+ help="Path to evolv_mem_v4.json")
410
+ parser.add_argument("--out_file", type=str, required=True,
411
+ help="Output JSONL file")
412
+ parser.add_argument("--sessions_file", type=str, default=None,
413
+ help="Path to all_sessions.json (only needed with --use_llm_memos)")
414
+ parser.add_argument("--summary_file", type=str, default=None,
415
+ help="Path to all_session_summary.json (used by default for memo bank)")
416
+ parser.add_argument("--profile_file", type=str, default=None,
417
+ help="Path to generated_user_profile.json")
418
+ parser.add_argument("--memo_cache_dir", type=str,
419
+ default="baselines/MemoChat/memo_cache",
420
+ help="Directory to cache per-session memos")
421
+ parser.add_argument("--prompt_file", type=str, default=None,
422
+ help="Path to prompts.json (default: baselines/MemoChat/data/prompts.json)")
423
+ # Retrieval params
424
+ parser.add_argument("--embed_top_k", type=int, default=50,
425
+ help="Number of memos to keep after embedding pre-filter (default 50)")
426
+ parser.add_argument("--embedding_model", type=str,
427
+ default="sentence-transformers/multi-qa-mpnet-base-cos-v1",
428
+ help="SentenceTransformer model for embedding pre-filter")
429
+ # Limit (for debugging)
430
+ parser.add_argument("--limit", type=int, default=None,
431
+ help="Process only the first N questions")
432
+ parser.add_argument("--use_llm_memos", action="store_true", default=False,
433
+ help="Use LLM-based memo writing instead of cached session summaries")
434
+ args = parser.parse_args()
435
+
436
+ # -----------------------------------------------------------------------
437
+ # Load data
438
+ # -----------------------------------------------------------------------
439
+ print(f"Loading benchmark from {args.in_file} ...")
440
+ with open(args.in_file) as f:
441
+ benchmark = json.load(f)
442
+ if args.limit:
443
+ benchmark = benchmark[:args.limit]
444
+ print(f" {len(benchmark)} questions loaded.")
445
+
446
+ all_sessions = {}
447
+ if args.sessions_file and os.path.exists(args.sessions_file):
448
+ print(f"Loading sessions from {args.sessions_file} ...")
449
+ with open(args.sessions_file) as f:
450
+ all_sessions = json.load(f)
451
+ print(f" {len(all_sessions)} sessions loaded.")
452
+
453
+ summaries = {}
454
+ if args.summary_file and os.path.exists(args.summary_file):
455
+ print(f"Loading session summaries from {args.summary_file} ...")
456
+ with open(args.summary_file) as f:
457
+ summaries = json.load(f)
458
+ print(f" {len(summaries)} session summaries loaded.")
459
+
460
+ if not args.use_llm_memos and not summaries:
461
+ print("ERROR: --summary_file is required unless --use_llm_memos is set.")
462
+ sys.exit(1)
463
+ if args.use_llm_memos and not all_sessions:
464
+ print("ERROR: --sessions_file is required when --use_llm_memos is set.")
465
+ sys.exit(1)
466
+
467
+ profiles = {}
468
+ if args.profile_file and os.path.exists(args.profile_file):
469
+ print(f"Loading user profiles from {args.profile_file} ...")
470
+ with open(args.profile_file) as f:
471
+ profiles = json.load(f)
472
+ print(f" {len(profiles)} profiles loaded.")
473
+
474
+ prompt_file = args.prompt_file or PROMPTS_PATH
475
+ print(f"Loading prompts from {prompt_file} ...")
476
+ with open(prompt_file) as f:
477
+ prompts = json.load(f)
478
+
479
+ # -----------------------------------------------------------------------
480
+ # Resume support
481
+ # -----------------------------------------------------------------------
482
+ existing_qids = set()
483
+ if os.path.exists(args.out_file):
484
+ with open(args.out_file) as f:
485
+ for line in f:
486
+ line = line.strip()
487
+ if line:
488
+ obj = json.loads(line)
489
+ existing_qids.add(obj["question_id"])
490
+ print(f" Resuming: {len(existing_qids)} questions already processed.")
491
+
492
+ # -----------------------------------------------------------------------
493
+ # Initialize models
494
+ # -----------------------------------------------------------------------
495
+ print("Initializing embedding model ...")
496
+ from sentence_transformers import SentenceTransformer
497
+ embedding_model = SentenceTransformer(args.embedding_model)
498
+
499
+ print("Initializing vLLM client ...")
500
+ client = get_llm_client()
501
+
502
+ os.makedirs(args.memo_cache_dir, exist_ok=True)
503
+
504
+ # -----------------------------------------------------------------------
505
+ # Process questions
506
+ # -----------------------------------------------------------------------
507
+ retrieval_metric_list = []
508
+ out_f = open(args.out_file, "a")
509
+
510
+ for di, entry in enumerate(tqdm(benchmark, desc="MemoChat baseline")):
511
+ qid = entry["question_id"]
512
+ question = entry["question"]
513
+ question_date = entry["question_date"]
514
+
515
+ if qid in existing_qids:
516
+ continue
517
+
518
+ try:
519
+ haystack_session_ids = entry["haystack_session_ids"]
520
+
521
+ # ------ Stage 1: Build memo bank ------
522
+ if args.use_llm_memos:
523
+ # Slow path: LLM-based memo writing (cached per session)
524
+ all_memos = []
525
+ n_cached = 0
526
+ n_written = 0
527
+ date_lookup = dict(zip(
528
+ entry["haystack_session_ids"], entry["haystack_dates"]
529
+ ))
530
+ for sid in haystack_session_ids:
531
+ session_turns = all_sessions.get(sid, [])
532
+ cache_exists = os.path.exists(
533
+ os.path.join(args.memo_cache_dir, f"{sid}.json")
534
+ )
535
+ memos = write_memos_for_session(
536
+ sid, session_turns, client, prompts, args.memo_cache_dir
537
+ )
538
+ for m in memos:
539
+ m["session_date"] = date_lookup.get(sid, "")
540
+ all_memos.extend(memos)
541
+ if cache_exists:
542
+ n_cached += 1
543
+ else:
544
+ n_written += 1
545
+ print(f" [{di}] qid={qid}: {len(all_memos)} memos "
546
+ f"({n_cached} cached, {n_written} new)")
547
+ else:
548
+ # Fast path: use pre-computed session summaries as memos
549
+ all_memos = build_memos_from_summaries(
550
+ haystack_session_ids, entry["haystack_dates"], summaries
551
+ )
552
+ print(f" [{di}] qid={qid}: {len(all_memos)} memos from summaries")
553
+
554
+ if not all_memos:
555
+ result = {
556
+ "q_idx": di,
557
+ "question_id": qid,
558
+ "hypothesis": "Insufficient information to answer.",
559
+ "n_memos": 0,
560
+ }
561
+ print(json.dumps(result), file=out_f, flush=True)
562
+ continue
563
+
564
+ # ------ Stage 2: Embedding pre-filter ------
565
+ filtered_memos = embed_and_filter(
566
+ question, all_memos, embedding_model, top_k=args.embed_top_k
567
+ )
568
+ print(f" [{di}] Embedding filter: {len(all_memos)} -> {len(filtered_memos)} memos")
569
+
570
+ # ------ Stage 3: MemoChat LLM retrieval ------
571
+ retrieved_memos = memochat_retrieve(
572
+ question, filtered_memos, client, prompts
573
+ )
574
+ print(f" [{di}] MemoChat retrieval: {len(filtered_memos)} -> {len(retrieved_memos)} memos")
575
+
576
+ # Fallback: if retrieval selected nothing, use top-5 from embedding filter
577
+ if not retrieved_memos:
578
+ retrieved_memos = filtered_memos[:5]
579
+ print(f" [{di}] Fallback: using top-5 from embedding filter")
580
+
581
+ # ------ Stage 4: Answer generation ------
582
+ user_id = qid.split("_q_")[0] if "_q_" in qid else qid
583
+ user_profile = profiles.get(user_id, None)
584
+
585
+ answer = generate_answer(
586
+ question, question_date, retrieved_memos,
587
+ user_profile, client, prompts
588
+ )
589
+
590
+ # ------ Output ------
591
+ retrieved_session_ids = list(dict.fromkeys(
592
+ m["session_id"] for m in retrieved_memos if "session_id" in m
593
+ ))
594
+
595
+ # Compute retrieval metrics
596
+ answer_session_ids = entry.get("answer_session_ids", [])
597
+ retrieval_metric = {}
598
+ if answer_session_ids and retrieved_session_ids:
599
+ for topk in [5, 10, 20, 30]:
600
+ r_any, r_all = evaluate_retrieval(
601
+ retrieved_session_ids[:topk], answer_session_ids
602
+ )
603
+ retrieval_metric[f"recall_any@{topk}"] = r_any
604
+ retrieval_metric[f"recall_all@{topk}"] = r_all
605
+ retrieval_metric_list.append(retrieval_metric)
606
+ print_average_metrics(retrieval_metric_list)
607
+
608
+ result = {
609
+ "q_idx": di,
610
+ "question_id": qid,
611
+ "hypothesis": answer,
612
+ "n_memos_total": len(all_memos),
613
+ "n_memos_filtered": len(filtered_memos),
614
+ "n_memos_retrieved": len(retrieved_memos),
615
+ "retrieved_session_ids": retrieved_session_ids,
616
+ "retrieval_metric": retrieval_metric,
617
+ }
618
+ print(json.dumps(result), file=out_f, flush=True)
619
+
620
+ print(f" [{di}] Q: {question[:100]}...")
621
+ print(f" [{di}] A: {answer[:200]}...")
622
+
623
+ except Exception as e:
624
+ print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
625
+ import traceback
626
+ traceback.print_exc()
627
+ continue
628
+
629
+ out_f.close()
630
+ print(f"\nDone. Results saved to {args.out_file}")
631
+
632
+
633
+ if __name__ == "__main__":
634
+ main()
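
Two small behaviors of the pipeline above are easy to sanity-check in isolation: `memochat_retrieve` expects the LLM to return `#`-separated option numbers (out-of-range numbers such as the NOTO choice and any stray text are dropped), and `evaluate_retrieval` scores a question by whether any / all gold session ids appear among the retrieved ones. A toy check of both, with invented memos and session ids:

```python
# Toy check of the '#'-separated index parsing and the recall metrics above.
# Memos and session ids are invented.
def parse_selection(output: str, candidates: list) -> list:
    selected = []
    for part in output.split("#"):
        try:
            idx = int(part.strip()) - 1
            if 0 <= idx < len(candidates):
                selected.append(candidates[idx])
        except ValueError:
            continue  # stray text is skipped; a numeric NOTO index fails the range check
    return selected

memos = [{"session_id": "s1"}, {"session_id": "s2"}, {"session_id": "s3"}]
picked = parse_selection("1#3#4", memos)   # "4" is the NOTO option for 3 candidates
print([m["session_id"] for m in picked])   # ['s1', 's3']

recalled = [m["session_id"] for m in picked]
gold = ["s1", "s2"]
recall_any = float(any(s in recalled for s in gold))
recall_all = float(all(s in recalled for s in gold))
print(recall_any, recall_all)              # 1.0 0.0
```
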
baselines/raptor/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ The MIT License
2
+
3
+ Copyright (c) Parth Sarthi
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in
13
+ all copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ THE SOFTWARE.
baselines/raptor/README.md ADDED
@@ -0,0 +1,204 @@
1
+ <!-- <p align="center">
2
+ <img align="center" src="raptor.jpg" width="1000px" />
3
+ </p>
4
+ <p align="left"> -->
5
+
6
+ <!-- <picture>
7
+ <source media="(prefers-color-scheme: dark)" srcset="raptor.jpg" width="1000px">
8
+ <source media="(prefers-color-scheme: light)" srcset="raptor_dark.png" width="1000px">
9
+
10
+ </picture> -->
11
+
12
+ <picture>
13
+ <source media="(prefers-color-scheme: dark)" srcset="raptor_dark.png">
14
+ <img alt="Shows an illustrated sun in light color mode and a moon with stars in dark color mode." src="raptor.jpg">
15
+ </picture>
16
+
17
+ ## RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval
18
+
19
+ **RAPTOR** introduces a novel approach to retrieval-augmented language models by constructing a recursive tree structure from documents. This allows for more efficient and context-aware information retrieval across large texts, addressing common limitations in traditional language models.
20
+
21
+
22
+
23
+ For detailed methodologies and implementations, refer to the original paper:
24
+
25
+ - [RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval](https://arxiv.org/abs/2401.18059)
26
+
27
+ [![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-sm.svg)](https://huggingface.co/papers/2401.18059)
28
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/raptor-recursive-abstractive-processing-for/question-answering-on-quality)](https://paperswithcode.com/sota/question-answering-on-quality?p=raptor-recursive-abstractive-processing-for)
29
+
30
+ ## Installation
31
+
32
+ Before using RAPTOR, ensure Python 3.8+ is installed. Clone the RAPTOR repository and install necessary dependencies:
33
+
34
+ ```bash
35
+ git clone https://github.com/parthsarthi03/raptor.git
36
+ cd raptor
37
+ pip install -r requirements.txt
38
+ ```
39
+
40
+ ## Basic Usage
41
+
42
+ To get started with RAPTOR, follow these steps:
43
+
44
+ ### Setting Up RAPTOR
45
+
46
+ First, set your OpenAI API key and initialize the RAPTOR configuration:
47
+
48
+ ```python
49
+ import os
50
+ os.environ["OPENAI_API_KEY"] = "your-openai-api-key"
51
+
52
+ from raptor import RetrievalAugmentation
53
+
54
+ # Initialize with default configuration. For advanced configurations, check the documentation. [WIP]
55
+ RA = RetrievalAugmentation()
56
+ ```
57
+
58
+ ### Adding Documents to the Tree
59
+
60
+ Add your text documents to RAPTOR for indexing:
61
+
62
+ ```python
63
+ with open('sample.txt', 'r') as file:
64
+ text = file.read()
65
+ RA.add_documents(text)
66
+ ```
67
+
68
+ ### Answering Questions
69
+
70
+ You can now use RAPTOR to answer questions based on the indexed documents:
71
+
72
+ ```python
73
+ question = "How did Cinderella reach her happy ending?"
74
+ answer = RA.answer_question(question=question)
75
+ print("Answer: ", answer)
76
+ ```
77
+
78
+ ### Saving and Loading the Tree
79
+
80
+ Save the constructed tree to a specified path:
81
+
82
+ ```python
83
+ SAVE_PATH = "demo/cinderella"
84
+ RA.save(SAVE_PATH)
85
+ ```
86
+
87
+ Load the saved tree back into RAPTOR:
88
+
89
+ ```python
90
+ RA = RetrievalAugmentation(tree=SAVE_PATH)
91
+ answer = RA.answer_question(question=question)
92
+ ```
93
+
94
+
95
+ ### Extending RAPTOR with other Models
96
+
97
+ RAPTOR is designed to be flexible and allows you to integrate any models for summarization, question-answering (QA), and embedding generation. Here is how to extend RAPTOR with your own models:
98
+
99
+ #### Custom Summarization Model
100
+
101
+ If you wish to use a different language model for summarization, you can do so by extending the `BaseSummarizationModel` class. Implement the `summarize` method to integrate your custom summarization logic:
102
+
103
+ ```python
104
+ from raptor import BaseSummarizationModel
105
+
106
+ class CustomSummarizationModel(BaseSummarizationModel):
107
+ def __init__(self):
108
+ # Initialize your model here
109
+ pass
110
+
111
+ def summarize(self, context, max_tokens=150):
112
+ # Implement your summarization logic here
113
+ # Return the summary as a string
114
+ summary = "Your summary here"
115
+ return summary
116
+ ```
117
+
118
+ #### Custom QA Model
119
+
120
+ For custom QA models, extend the `BaseQAModel` class and implement the `answer_question` method. This method should return the best answer found by your model given a context and a question:
121
+
122
+ ```python
123
+ from raptor import BaseQAModel
124
+
125
+ class CustomQAModel(BaseQAModel):
126
+ def __init__(self):
127
+ # Initialize your model here
128
+ pass
129
+
130
+ def answer_question(self, context, question):
131
+ # Implement your QA logic here
132
+ # Return the answer as a string
133
+ answer = "Your answer here"
134
+ return answer
135
+ ```
136
+
137
+ #### Custom Embedding Model
138
+
139
+ To use a different embedding model, extend the `BaseEmbeddingModel` class. Implement the `create_embedding` method, which should return a vector representation of the input text:
140
+
141
+ ```python
142
+ from raptor import BaseEmbeddingModel
143
+
144
+ class CustomEmbeddingModel(BaseEmbeddingModel):
145
+ def __init__(self):
146
+ # Initialize your model here
147
+ pass
148
+
149
+ def create_embedding(self, text):
150
+ # Implement your embedding logic here
151
+ # Return the embedding as a numpy array or a list of floats
152
+ embedding = [0.0] * embedding_dim # Replace with actual embedding logic
153
+ return embedding
154
+ ```
155
+
156
+ #### Integrating Custom Models with RAPTOR
157
+
158
+ After implementing your custom models, integrate them with RAPTOR as follows:
159
+
160
+ ```python
161
+ from raptor import RetrievalAugmentation, RetrievalAugmentationConfig
162
+
163
+ # Initialize your custom models
164
+ custom_summarizer = CustomSummarizationModel()
165
+ custom_qa = CustomQAModel()
166
+ custom_embedding = CustomEmbeddingModel()
167
+
168
+ # Create a config with your custom models
169
+ custom_config = RetrievalAugmentationConfig(
170
+ summarization_model=custom_summarizer,
171
+ qa_model=custom_qa,
172
+ embedding_model=custom_embedding
173
+ )
174
+
175
+ # Initialize RAPTOR with your custom config
176
+ RA = RetrievalAugmentation(config=custom_config)
177
+ ```
178
+
179
+ Check out `demo.ipynb` for examples on how to specify your own summarization/QA models, such as Llama/Mistral/Gemma, and Embedding Models such as SBERT, for use with RAPTOR.
180
+
181
+ Note: More examples and ways to configure RAPTOR are forthcoming. Advanced usage and additional features will be provided in the documentation and repository updates.
182
+
183
+ ## Contributing
184
+
185
+ RAPTOR is an open-source project, and contributions are welcome. Whether you're fixing bugs, adding new features, or improving documentation, your help is appreciated.
186
+
187
+ ## License
188
+
189
+ RAPTOR is released under the MIT License. See the LICENSE file in the repository for full details.
190
+
191
+ ## Citation
192
+
193
+ If RAPTOR assists in your research, please cite it as follows:
194
+
195
+ ```bibtex
196
+ @inproceedings{sarthi2024raptor,
197
+ title={RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval},
198
+ author={Sarthi, Parth and Abdullah, Salman and Tuli, Aditi and Khanna, Shubh and Goldie, Anna and Manning, Christopher D.},
199
+ booktitle={International Conference on Learning Representations (ICLR)},
200
+ year={2024}
201
+ }
202
+ ```
203
+
204
+ Stay tuned for more examples, configuration guides, and updates.
baselines/raptor/raptor/EmbeddingModels.py ADDED
@@ -0,0 +1,37 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+
4
+ from openai import OpenAI
5
+ from sentence_transformers import SentenceTransformer
6
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
7
+
8
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
9
+
10
+
11
+ class BaseEmbeddingModel(ABC):
12
+ @abstractmethod
13
+ def create_embedding(self, text):
14
+ pass
15
+
16
+
17
+ class OpenAIEmbeddingModel(BaseEmbeddingModel):
18
+ def __init__(self, model="text-embedding-ada-002"):
19
+ self.client = OpenAI()
20
+ self.model = model
21
+
22
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
23
+ def create_embedding(self, text):
24
+ text = text.replace("\n", " ")
25
+ return (
26
+ self.client.embeddings.create(input=[text], model=self.model)
27
+ .data[0]
28
+ .embedding
29
+ )
30
+
31
+
32
+ class SBertEmbeddingModel(BaseEmbeddingModel):
33
+ def __init__(self, model_name="sentence-transformers/multi-qa-mpnet-base-cos-v1"):
34
+ self.model = SentenceTransformer(model_name)
35
+
36
+ def create_embedding(self, text):
37
+ return self.model.encode(text)
baselines/raptor/raptor/FaissRetriever.py ADDED
@@ -0,0 +1,201 @@
1
+ import random
2
+ from concurrent.futures import ProcessPoolExecutor
3
+
4
+ import faiss
5
+ import numpy as np
6
+ import tiktoken
7
+ from tqdm import tqdm
8
+
9
+ from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel
10
+ from .Retrievers import BaseRetriever
11
+ from .utils import split_text
12
+
13
+
14
+ class FaissRetrieverConfig:
15
+ def __init__(
16
+ self,
17
+ max_tokens=100,
18
+ max_context_tokens=3500,
19
+ use_top_k=False,
20
+ embedding_model=None,
21
+ question_embedding_model=None,
22
+ top_k=5,
23
+ tokenizer=tiktoken.get_encoding("cl100k_base"),
24
+ embedding_model_string=None,
25
+ ):
26
+ if max_tokens < 1:
27
+ raise ValueError("max_tokens must be at least 1")
28
+
29
+ if top_k < 1:
30
+ raise ValueError("top_k must be at least 1")
31
+
32
+ if max_context_tokens is not None and max_context_tokens < 1:
33
+ raise ValueError("max_context_tokens must be at least 1 or None")
34
+
35
+ if embedding_model is not None and not isinstance(
36
+ embedding_model, BaseEmbeddingModel
37
+ ):
38
+ raise ValueError(
39
+ "embedding_model must be an instance of BaseEmbeddingModel or None"
40
+ )
41
+
42
+ if question_embedding_model is not None and not isinstance(
43
+ question_embedding_model, BaseEmbeddingModel
44
+ ):
45
+ raise ValueError(
46
+ "question_embedding_model must be an instance of BaseEmbeddingModel or None"
47
+ )
48
+
49
+ self.top_k = top_k
50
+ self.max_tokens = max_tokens
51
+ self.max_context_tokens = max_context_tokens
52
+ self.use_top_k = use_top_k
53
+ self.embedding_model = embedding_model or OpenAIEmbeddingModel()
54
+ self.question_embedding_model = question_embedding_model or self.embedding_model
55
+ self.tokenizer = tokenizer
56
+ self.embedding_model_string = embedding_model_string or "OpenAI"
57
+
58
+ def log_config(self):
59
+ config_summary = """
60
+ FaissRetrieverConfig:
61
+ Max Tokens: {max_tokens}
62
+ Max Context Tokens: {max_context_tokens}
63
+ Use Top K: {use_top_k}
64
+ Embedding Model: {embedding_model}
65
+ Question Embedding Model: {question_embedding_model}
66
+ Top K: {top_k}
67
+ Tokenizer: {tokenizer}
68
+ Embedding Model String: {embedding_model_string}
69
+ """.format(
70
+ max_tokens=self.max_tokens,
71
+ max_context_tokens=self.max_context_tokens,
72
+ use_top_k=self.use_top_k,
73
+ embedding_model=self.embedding_model,
74
+ question_embedding_model=self.question_embedding_model,
75
+ top_k=self.top_k,
76
+ tokenizer=self.tokenizer,
77
+ embedding_model_string=self.embedding_model_string,
78
+ )
79
+ return config_summary
80
+
81
+
82
+ class FaissRetriever(BaseRetriever):
83
+ """
84
+ FaissRetriever is a class that retrieves similar context chunks for a given query using Faiss.
85
+ A separate question_embedding_model can be supplied when queries and context
86
+ chunks should be encoded by different models; by default the same embedding model is used for both.
87
+ """
88
+
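+ # Usage sketch (illustrative only; the methods used here are defined below):
+ #   retriever = FaissRetriever(FaissRetrieverConfig(use_top_k=True, top_k=5))
+ #   retriever.build_from_text(document_text)    # split, embed, and index the document
+ #   context = retriever.retrieve("some query")  # concatenation of the nearest chunks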
89
+ def __init__(self, config):
90
+ self.embedding_model = config.embedding_model
91
+ self.question_embedding_model = config.question_embedding_model
92
+ self.index = None
93
+ self.context_chunks = None
94
+ self.max_tokens = config.max_tokens
95
+ self.max_context_tokens = config.max_context_tokens
96
+ self.use_top_k = config.use_top_k
97
+ self.tokenizer = config.tokenizer
98
+ self.top_k = config.top_k
99
+ self.embedding_model_string = config.embedding_model_string
100
+
101
+ def build_from_text(self, doc_text):
102
+ """
103
+ Builds the index from a given text.
104
+
105
+ :param doc_text: A string containing the document text.
106
+ :param tokenizer: A tokenizer used to split the text into chunks.
107
+ :param max_tokens: An integer representing the maximum number of tokens per chunk.
108
+ """
109
+ self.context_chunks = np.array(
110
+ split_text(doc_text, self.tokenizer, self.max_tokens)
111
+ )
112
+
113
+ with ProcessPoolExecutor() as executor:
114
+ futures = [
115
+ executor.submit(self.embedding_model.create_embedding, context_chunk)
116
+ for context_chunk in self.context_chunks
117
+ ]
118
+
119
+ self.embeddings = []
120
+ for future in tqdm(futures, total=len(futures), desc="Building embeddings"):
121
+ self.embeddings.append(future.result())
122
+
123
+ self.embeddings = np.array(self.embeddings, dtype=np.float32)
124
+
125
+ self.index = faiss.IndexFlatIP(self.embeddings.shape[1])
126
+ self.index.add(self.embeddings)
127
+
128
+ def build_from_leaf_nodes(self, leaf_nodes):
129
+ """
130
+ Builds the index from the pre-embedded leaf nodes of a RAPTOR tree.
131
+
132
+ :param leaf_nodes: A list of Node objects; each node's text becomes a context chunk
133
+ and its stored embedding (keyed by embedding_model_string) is added to the
134
+ Faiss index.
135
+ """
136
+
137
+ self.context_chunks = [node.text for node in leaf_nodes]
138
+
139
+ self.embeddings = np.array(
140
+ [node.embeddings[self.embedding_model_string] for node in leaf_nodes],
141
+ dtype=np.float32,
142
+ )
143
+
144
+ self.index = faiss.IndexFlatIP(self.embeddings.shape[1])
145
+ self.index.add(self.embeddings)
146
+
147
+ def sanity_check(self, num_samples=4):
148
+ """
149
+ Perform a sanity check by recomputing embeddings of a few randomly-selected chunks.
150
+
151
+ :param num_samples: The number of samples to test.
152
+ """
153
+ indices = random.sample(range(len(self.context_chunks)), num_samples)
154
+
155
+ for i in indices:
156
+ original_embedding = self.embeddings[i]
157
+ recomputed_embedding = self.embedding_model.create_embedding(
158
+ self.context_chunks[i]
159
+ )
160
+ assert np.allclose(
161
+ original_embedding, recomputed_embedding
162
+ ), f"Embeddings do not match for index {i}!"
163
+
164
+ print(f"Sanity check passed for {num_samples} random samples.")
165
+
166
+ def retrieve(self, query: str) -> str:
167
+ """
168
+ Retrieves the k most similar context chunks for a given query.
169
+
170
+ :param query: A string containing the query.
171
+ :param k: An integer representing the number of similar context chunks to retrieve.
172
+ :return: A string containing the retrieved context chunks.
173
+ """
174
+ query_embedding = np.array(
175
+ [
176
+ np.array(
177
+ self.question_embedding_model.create_embedding(query),
178
+ dtype=np.float32,
179
+ ).squeeze()
180
+ ]
181
+ )
182
+
183
+ context = ""
184
+
185
+ if self.use_top_k:
186
+ _, indices = self.index.search(query_embedding, self.top_k)
187
+ for i in range(self.top_k):
188
+ context += self.context_chunks[indices[0][i]]
189
+
190
+ else:
191
+ range_ = int(self.max_context_tokens / self.max_tokens)
192
+ _, indices = self.index.search(query_embedding, range_)
193
+ total_tokens = 0
194
+ for i in range(range_):
195
+ tokens = len(self.tokenizer.encode(self.context_chunks[indices[0][i]]))
196
+ context += self.context_chunks[indices[0][i]]
197
+ if total_tokens + tokens > self.max_context_tokens:
198
+ break
199
+ total_tokens += tokens
200
+
201
+ return context
baselines/raptor/raptor/QAModels.py ADDED
@@ -0,0 +1,185 @@
1
+ import logging
2
+ import os
3
+
4
+ from openai import OpenAI
5
+
6
+
7
+ import getpass
8
+ from abc import ABC, abstractmethod
9
+
10
+ import torch
11
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
12
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
13
+
14
+
15
+ class BaseQAModel(ABC):
16
+ @abstractmethod
17
+ def answer_question(self, context, question):
18
+ pass
19
+
20
+
21
+ class GPT3QAModel(BaseQAModel):
22
+ def __init__(self, model="text-davinci-003"):
23
+ """
24
+ Initializes the GPT-3 model with the specified model version.
25
+
26
+ Args:
27
+ model (str, optional): The GPT-3 model version to use for generating summaries. Defaults to "text-davinci-003".
28
+ """
29
+ self.model = model
30
+ self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
31
+
32
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
33
+ def answer_question(self, context, question, max_tokens=150, stop_sequence=None):
34
+ """
35
+ Answers the given question from the provided context using the GPT-3 completions model.
36
+
37
+ Args:
38
+ context (str): The text to summarize.
39
+ max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.
40
+ stop_sequence (str, optional): The sequence at which to stop summarization. Defaults to None.
41
+
42
+ Returns:
43
+ str: The generated summary.
44
+ """
45
+ try:
46
+ response = self.client.completions.create(
47
+ prompt=f"Using the following information: {context}. Answer the following question in less than 5-7 words, if possible: {question}",
48
+ temperature=0,
49
+ max_tokens=max_tokens,
50
+ top_p=1,
51
+ frequency_penalty=0,
52
+ presence_penalty=0,
53
+ stop=stop_sequence,
54
+ model=self.model,
55
+ )
56
+ return response.choices[0].text.strip()
57
+
58
+ except Exception as e:
59
+ print(e)
60
+ return ""
61
+
62
+
63
+ class GPT3TurboQAModel(BaseQAModel):
64
+ def __init__(self, model="gpt-3.5-turbo"):
65
+ """
66
+ Initializes the GPT-3 model with the specified model version.
67
+
68
+ Args:
69
+ model (str, optional): The OpenAI chat model to use. Defaults to "gpt-3.5-turbo".
70
+ """
71
+ self.model = model
72
+ self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
73
+
74
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
75
+ def _attempt_answer_question(
76
+ self, context, question, max_tokens=150, stop_sequence=None
77
+ ):
78
+ """
79
+ Answers the given question from the provided context using the chat model.
80
+
81
+ Args:
82
+ context (str): The text to summarize.
83
+ max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.
84
+ stop_sequence (str, optional): The sequence at which to stop summarization. Defaults to None.
85
+
86
+ Returns:
87
+ str: The generated summary.
88
+ """
89
+ response = self.client.chat.completions.create(
90
+ model=self.model,
91
+ messages=[
92
+ {"role": "system", "content": "You are Question Answering Portal"},
93
+ {
94
+ "role": "user",
95
+ "content": f"Given Context: {context} Give the best full answer amongst the option to question {question}",
96
+ },
97
+ ],
98
+ temperature=0,
99
+ )
100
+
101
+ return response.choices[0].message.content.strip()
102
+
103
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
104
+ def answer_question(self, context, question, max_tokens=150, stop_sequence=None):
105
+
106
+ try:
107
+ return self._attempt_answer_question(
108
+ context, question, max_tokens=max_tokens, stop_sequence=stop_sequence
109
+ )
110
+ except Exception as e:
111
+ print(e)
112
+ return str(e)  # return the error message as a string rather than the exception object
113
+
114
+
115
+ class GPT4QAModel(BaseQAModel):
116
+ def __init__(self, model="gpt-4"):
117
+ """
118
+ Initializes the GPT-4 model with the specified model version.
119
+
120
+ Args:
121
+ model (str, optional): The OpenAI chat model to use. Defaults to "gpt-4".
122
+ """
123
+ self.model = model
124
+ self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
125
+
126
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
127
+ def _attempt_answer_question(
128
+ self, context, question, max_tokens=150, stop_sequence=None
129
+ ):
130
+ """
131
+ Answers the given question from the provided context using the chat model.
132
+
133
+ Args:
134
+ context (str): The text to summarize.
135
+ max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.
136
+ stop_sequence (str, optional): The sequence at which to stop summarization. Defaults to None.
137
+
138
+ Returns:
139
+ str: The generated summary.
140
+ """
141
+ response = self.client.chat.completions.create(
142
+ model=self.model,
143
+ messages=[
144
+ {"role": "system", "content": "You are Question Answering Portal"},
145
+ {
146
+ "role": "user",
147
+ "content": f"Given Context: {context} Give the best full answer amongst the option to question {question}",
148
+ },
149
+ ],
150
+ temperature=0,
151
+ )
152
+
153
+ return response.choices[0].message.content.strip()
154
+
155
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
156
+ def answer_question(self, context, question, max_tokens=150, stop_sequence=None):
157
+
158
+ try:
159
+ return self._attempt_answer_question(
160
+ context, question, max_tokens=max_tokens, stop_sequence=stop_sequence
161
+ )
162
+ except Exception as e:
163
+ print(e)
164
+ return str(e)  # return the error message as a string rather than the exception object
165
+
166
+
167
+ class UnifiedQAModel(BaseQAModel):
168
+ def __init__(self, model_name="allenai/unifiedqa-v2-t5-3b-1363200"):
169
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
170
+ self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(
171
+ self.device
172
+ )
173
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
174
+
175
+ def run_model(self, input_string, **generator_args):
176
+ input_ids = self.tokenizer.encode(input_string, return_tensors="pt").to(
177
+ self.device
178
+ )
179
+ res = self.model.generate(input_ids, **generator_args)
180
+ return self.tokenizer.batch_decode(res, skip_special_tokens=True)
181
+
182
+ def answer_question(self, context, question):
183
+ input_string = question + " \\n " + context
184
+ output = self.run_model(input_string)
185
+ return output[0]
baselines/raptor/raptor/RetrievalAugmentation.py ADDED
@@ -0,0 +1,306 @@
 
 
1
+ import logging
2
+ import pickle
3
+
4
+ from .cluster_tree_builder import ClusterTreeBuilder, ClusterTreeConfig
5
+ from .EmbeddingModels import BaseEmbeddingModel
6
+ from .QAModels import BaseQAModel, GPT3TurboQAModel
7
+ from .SummarizationModels import BaseSummarizationModel
8
+ from .tree_builder import TreeBuilder, TreeBuilderConfig
9
+ from .tree_retriever import TreeRetriever, TreeRetrieverConfig
10
+ from .tree_structures import Node, Tree
11
+
12
+ # Define a dictionary to map supported tree builders to their respective configs
13
+ supported_tree_builders = {"cluster": (ClusterTreeBuilder, ClusterTreeConfig)}
14
+
15
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
16
+
17
+
18
+ class RetrievalAugmentationConfig:
19
+ def __init__(
20
+ self,
21
+ tree_builder_config=None,
22
+ tree_retriever_config=None, # Change from default instantiation
23
+ qa_model=None,
24
+ embedding_model=None,
25
+ summarization_model=None,
26
+ tree_builder_type="cluster",
27
+ # New parameters for TreeRetrieverConfig and TreeBuilderConfig
28
+ # TreeRetrieverConfig arguments
29
+ tr_tokenizer=None,
30
+ tr_threshold=0.5,
31
+ tr_top_k=5,
32
+ tr_selection_mode="top_k",
33
+ tr_context_embedding_model="OpenAI",
34
+ tr_embedding_model=None,
35
+ tr_num_layers=None,
36
+ tr_start_layer=None,
37
+ # TreeBuilderConfig arguments
38
+ tb_tokenizer=None,
39
+ tb_max_tokens=100,
40
+ tb_num_layers=5,
41
+ tb_threshold=0.5,
42
+ tb_top_k=5,
43
+ tb_selection_mode="top_k",
44
+ tb_summarization_length=100,
45
+ tb_summarization_model=None,
46
+ tb_embedding_models=None,
47
+ tb_cluster_embedding_model="OpenAI",
48
+ ):
49
+ # Validate tree_builder_type
50
+ if tree_builder_type not in supported_tree_builders:
51
+ raise ValueError(
52
+ f"tree_builder_type must be one of {list(supported_tree_builders.keys())}"
53
+ )
54
+
55
+ # Validate qa_model
56
+ if qa_model is not None and not isinstance(qa_model, BaseQAModel):
57
+ raise ValueError("qa_model must be an instance of BaseQAModel")
58
+
59
+ if embedding_model is not None and not isinstance(
60
+ embedding_model, BaseEmbeddingModel
61
+ ):
62
+ raise ValueError(
63
+ "embedding_model must be an instance of BaseEmbeddingModel"
64
+ )
65
+ elif embedding_model is not None:
66
+ if tb_embedding_models is not None:
67
+ raise ValueError(
68
+ "Only one of 'tb_embedding_models' or 'embedding_model' should be provided, not both."
69
+ )
70
+ tb_embedding_models = {"EMB": embedding_model}
71
+ tr_embedding_model = embedding_model
72
+ tb_cluster_embedding_model = "EMB"
73
+ tr_context_embedding_model = "EMB"
74
+
75
+ if summarization_model is not None and not isinstance(
76
+ summarization_model, BaseSummarizationModel
77
+ ):
78
+ raise ValueError(
79
+ "summarization_model must be an instance of BaseSummarizationModel"
80
+ )
81
+
82
+ elif summarization_model is not None:
83
+ if tb_summarization_model is not None:
84
+ raise ValueError(
85
+ "Only one of 'tb_summarization_model' or 'summarization_model' should be provided, not both."
86
+ )
87
+ tb_summarization_model = summarization_model
88
+
89
+ # Set TreeBuilderConfig
90
+ tree_builder_class, tree_builder_config_class = supported_tree_builders[
91
+ tree_builder_type
92
+ ]
93
+ if tree_builder_config is None:
94
+ tree_builder_config = tree_builder_config_class(
95
+ tokenizer=tb_tokenizer,
96
+ max_tokens=tb_max_tokens,
97
+ num_layers=tb_num_layers,
98
+ threshold=tb_threshold,
99
+ top_k=tb_top_k,
100
+ selection_mode=tb_selection_mode,
101
+ summarization_length=tb_summarization_length,
102
+ summarization_model=tb_summarization_model,
103
+ embedding_models=tb_embedding_models,
104
+ cluster_embedding_model=tb_cluster_embedding_model,
105
+ )
106
+
107
+ elif not isinstance(tree_builder_config, tree_builder_config_class):
108
+ raise ValueError(
109
+ f"tree_builder_config must be a direct instance of {tree_builder_config_class} for tree_builder_type '{tree_builder_type}'"
110
+ )
111
+
112
+ # Set TreeRetrieverConfig
113
+ if tree_retriever_config is None:
114
+ tree_retriever_config = TreeRetrieverConfig(
115
+ tokenizer=tr_tokenizer,
116
+ threshold=tr_threshold,
117
+ top_k=tr_top_k,
118
+ selection_mode=tr_selection_mode,
119
+ context_embedding_model=tr_context_embedding_model,
120
+ embedding_model=tr_embedding_model,
121
+ num_layers=tr_num_layers,
122
+ start_layer=tr_start_layer,
123
+ )
124
+ elif not isinstance(tree_retriever_config, TreeRetrieverConfig):
125
+ raise ValueError(
126
+ "tree_retriever_config must be an instance of TreeRetrieverConfig"
127
+ )
128
+
129
+ # Assign the created configurations to the instance
130
+ self.tree_builder_config = tree_builder_config
131
+ self.tree_retriever_config = tree_retriever_config
132
+ self.qa_model = qa_model or GPT3TurboQAModel()
133
+ self.tree_builder_type = tree_builder_type
134
+
135
+ def log_config(self):
136
+ config_summary = """
137
+ RetrievalAugmentationConfig:
138
+ {tree_builder_config}
139
+
140
+ {tree_retriever_config}
141
+
142
+ QA Model: {qa_model}
143
+ Tree Builder Type: {tree_builder_type}
144
+ """.format(
145
+ tree_builder_config=self.tree_builder_config.log_config(),
146
+ tree_retriever_config=self.tree_retriever_config.log_config(),
147
+ qa_model=self.qa_model,
148
+ tree_builder_type=self.tree_builder_type,
149
+ )
150
+ return config_summary
151
+
152
+
153
+ class RetrievalAugmentation:
154
+ """
155
+ A Retrieval Augmentation class that combines the TreeBuilder and TreeRetriever classes.
156
+ Enables adding documents to the tree, retrieving information, and answering questions.
157
+ """
158
+
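+ # Typical usage sketch (illustrative only; all methods are defined below):
+ #   RA = RetrievalAugmentation()              # default config: OpenAI embeddings, GPT-3.5-turbo QA
+ #   RA.add_documents(long_document_text)      # builds the RAPTOR tree
+ #   answer = RA.answer_question("question?")  # retrieves context and answers
+ #   RA.save("tree.pkl")                       # pickle the tree for later reuse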
159
+ def __init__(self, config=None, tree=None):
160
+ """
161
+ Initializes a RetrievalAugmentation instance with the specified configuration.
162
+ Args:
163
+ config (RetrievalAugmentationConfig): The configuration for the RetrievalAugmentation instance.
164
+ tree: The tree instance or the path to a pickled tree file.
165
+ """
166
+ if config is None:
167
+ config = RetrievalAugmentationConfig()
168
+ if not isinstance(config, RetrievalAugmentationConfig):
169
+ raise ValueError(
170
+ "config must be an instance of RetrievalAugmentationConfig"
171
+ )
172
+
173
+ # Check if tree is a string (indicating a path to a pickled tree)
174
+ if isinstance(tree, str):
175
+ try:
176
+ with open(tree, "rb") as file:
177
+ self.tree = pickle.load(file)
178
+ if not isinstance(self.tree, Tree):
179
+ raise ValueError("The loaded object is not an instance of Tree")
180
+ except Exception as e:
181
+ raise ValueError(f"Failed to load tree from {tree}: {e}")
182
+ elif isinstance(tree, Tree) or tree is None:
183
+ self.tree = tree
184
+ else:
185
+ raise ValueError(
186
+ "tree must be an instance of Tree, a path to a pickled Tree, or None"
187
+ )
188
+
189
+ tree_builder_class = supported_tree_builders[config.tree_builder_type][0]
190
+ self.tree_builder = tree_builder_class(config.tree_builder_config)
191
+
192
+ self.tree_retriever_config = config.tree_retriever_config
193
+ self.qa_model = config.qa_model
194
+
195
+ if self.tree is not None:
196
+ self.retriever = TreeRetriever(self.tree_retriever_config, self.tree)
197
+ else:
198
+ self.retriever = None
199
+
200
+ logging.info(
201
+ f"Successfully initialized RetrievalAugmentation with Config {config.log_config()}"
202
+ )
203
+
204
+ def add_documents(self, docs):
205
+ """
206
+ Adds documents to the tree and creates a TreeRetriever instance.
207
+
208
+ Args:
209
+ docs (str): The input text to add to the tree.
210
+ """
211
+ if self.tree is not None:
212
+ user_input = input(
213
+ "Warning: Overwriting existing tree. Did you mean to call 'add_to_existing' instead? (y/n): "
214
+ )
215
+ if user_input.lower() == "y":
216
+ # self.add_to_existing(docs)
217
+ return
218
+
219
+ self.tree = self.tree_builder.build_from_text(text=docs)
220
+ self.retriever = TreeRetriever(self.tree_retriever_config, self.tree)
221
+
222
+ def retrieve(
223
+ self,
224
+ question,
225
+ start_layer: int = None,
226
+ num_layers: int = None,
227
+ top_k: int = 10,
228
+ max_tokens: int = 3500,
229
+ collapse_tree: bool = True,
230
+ return_layer_information: bool = True,
231
+ ):
232
+ """
233
+ Retrieves information and answers a question using the TreeRetriever instance.
234
+
235
+ Args:
236
+ question (str): The question to answer.
237
+ start_layer (int): The layer to start from. Defaults to self.start_layer.
238
+ num_layers (int): The number of layers to traverse. Defaults to self.num_layers.
239
+ max_tokens (int): The maximum number of tokens. Defaults to 3500.
240
+ use_all_information (bool): Whether to retrieve information from all nodes. Defaults to False.
241
+
242
+ Returns:
243
+ str: The context from which the answer can be found.
244
+
245
+ Raises:
246
+ ValueError: If the TreeRetriever instance has not been initialized.
247
+ """
248
+ if self.retriever is None:
249
+ raise ValueError(
250
+ "The TreeRetriever instance has not been initialized. Call 'add_documents' first."
251
+ )
252
+
253
+ return self.retriever.retrieve(
254
+ question,
255
+ start_layer,
256
+ num_layers,
257
+ top_k,
258
+ max_tokens,
259
+ collapse_tree,
260
+ return_layer_information,
261
+ )
262
+
263
+ def answer_question(
264
+ self,
265
+ question,
266
+ top_k: int = 10,
267
+ start_layer: int = None,
268
+ num_layers: int = None,
269
+ max_tokens: int = 3500,
270
+ collapse_tree: bool = True,
271
+ return_layer_information: bool = False,
272
+ ):
273
+ """
274
+ Retrieves information and answers a question using the TreeRetriever instance.
275
+
276
+ Args:
277
+ question (str): The question to answer.
278
+ start_layer (int): The layer to start from. Defaults to self.start_layer.
279
+ num_layers (int): The number of layers to traverse. Defaults to self.num_layers.
280
+ max_tokens (int): The maximum number of tokens. Defaults to 3500.
281
+ use_all_information (bool): Whether to retrieve information from all nodes. Defaults to False.
282
+
283
+ Returns:
284
+ str: The answer to the question.
285
+
286
+ Raises:
287
+ ValueError: If the TreeRetriever instance has not been initialized.
288
+ """
289
+ # if return_layer_information:
290
+ context, layer_information = self.retrieve(
291
+ question, start_layer, num_layers, top_k, max_tokens, collapse_tree, True
292
+ )
293
+
294
+ answer = self.qa_model.answer_question(context, question)
295
+
296
+ if return_layer_information:
297
+ return answer, layer_information
298
+
299
+ return answer
300
+
301
+ def save(self, path):
302
+ if self.tree is None:
303
+ raise ValueError("There is no tree to save.")
304
+ with open(path, "wb") as file:
305
+ pickle.dump(self.tree, file)
306
+ logging.info(f"Tree successfully saved to {path}")
baselines/raptor/raptor/Retrievers.py ADDED
@@ -0,0 +1,8 @@
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List
3
+
4
+
5
+ class BaseRetriever(ABC):
6
+ @abstractmethod
7
+ def retrieve(self, query: str) -> str:
8
+ pass
baselines/raptor/raptor/SummarizationModels.py ADDED
@@ -0,0 +1,74 @@
 
 
1
+ import logging
2
+ import os
3
+ from abc import ABC, abstractmethod
4
+
5
+ from openai import OpenAI
6
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
7
+
8
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
9
+
10
+
11
+ class BaseSummarizationModel(ABC):
12
+ @abstractmethod
13
+ def summarize(self, context, max_tokens=150):
14
+ pass
15
+
16
+
17
+ class GPT3TurboSummarizationModel(BaseSummarizationModel):
18
+ def __init__(self, model="gpt-3.5-turbo"):
19
+
20
+ self.model = model
21
+
22
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
23
+ def summarize(self, context, max_tokens=500, stop_sequence=None):
24
+
25
+ try:
26
+ client = OpenAI()
27
+
28
+ response = client.chat.completions.create(
29
+ model=self.model,
30
+ messages=[
31
+ {"role": "system", "content": "You are a helpful assistant."},
32
+ {
33
+ "role": "user",
34
+ "content": f"Write a summary of the following, including as many key details as possible: {context}:",
35
+ },
36
+ ],
37
+ max_tokens=max_tokens,
38
+ )
39
+
40
+ return response.choices[0].message.content
41
+
42
+ except Exception as e:
43
+ print(e)
44
+ return str(e)  # return the error message as a string rather than the exception object
45
+
46
+
47
+ class GPT3SummarizationModel(BaseSummarizationModel):
48
+ def __init__(self, model="text-davinci-003"):
49
+
50
+ self.model = model
51
+
52
+ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
53
+ def summarize(self, context, max_tokens=500, stop_sequence=None):
54
+
55
+ try:
56
+ client = OpenAI()
57
+
58
+ response = client.chat.completions.create(
59
+ model=self.model,
60
+ messages=[
61
+ {"role": "system", "content": "You are a helpful assistant."},
62
+ {
63
+ "role": "user",
64
+ "content": f"Write a summary of the following, including as many key details as possible: {context}:",
65
+ },
66
+ ],
67
+ max_tokens=max_tokens,
68
+ )
69
+
70
+ return response.choices[0].message.content
71
+
72
+ except Exception as e:
73
+ print(e)
74
+ return str(e)  # return the error message as a string rather than the exception object
baselines/raptor/raptor/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
1
+ # raptor/__init__.py
2
+ from .cluster_tree_builder import ClusterTreeBuilder, ClusterTreeConfig
3
+ from .EmbeddingModels import (BaseEmbeddingModel, OpenAIEmbeddingModel,
4
+ SBertEmbeddingModel)
5
+ from .FaissRetriever import FaissRetriever, FaissRetrieverConfig
6
+ from .QAModels import (BaseQAModel, GPT3QAModel, GPT3TurboQAModel, GPT4QAModel,
7
+ UnifiedQAModel)
8
+ from .RetrievalAugmentation import (RetrievalAugmentation,
9
+ RetrievalAugmentationConfig)
10
+ from .Retrievers import BaseRetriever
11
+ from .SummarizationModels import (BaseSummarizationModel,
12
+ GPT3SummarizationModel,
13
+ GPT3TurboSummarizationModel)
14
+ from .tree_builder import TreeBuilder, TreeBuilderConfig
15
+ from .tree_retriever import TreeRetriever, TreeRetrieverConfig
16
+ from .tree_structures import Node, Tree
baselines/raptor/raptor/cluster_tree_builder.py ADDED
@@ -0,0 +1,151 @@
 
 
1
+ import logging
2
+ import pickle
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from threading import Lock
5
+ from typing import Dict, List, Set
6
+
7
+ from .cluster_utils import ClusteringAlgorithm, RAPTOR_Clustering
8
+ from .tree_builder import TreeBuilder, TreeBuilderConfig
9
+ from .tree_structures import Node, Tree
10
+ from .utils import (distances_from_embeddings, get_children, get_embeddings,
11
+ get_node_list, get_text,
12
+ indices_of_nearest_neighbors_from_distances, split_text)
13
+
14
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
15
+
16
+
17
+ class ClusterTreeConfig(TreeBuilderConfig):
18
+ def __init__(
19
+ self,
20
+ reduction_dimension=10,
21
+ clustering_algorithm=RAPTOR_Clustering, # Default to RAPTOR clustering
22
+ clustering_params={}, # Pass additional params as a dict
23
+ *args,
24
+ **kwargs,
25
+ ):
26
+ super().__init__(*args, **kwargs)
27
+ self.reduction_dimension = reduction_dimension
28
+ self.clustering_algorithm = clustering_algorithm
29
+ self.clustering_params = clustering_params
30
+
31
+ def log_config(self):
32
+ base_summary = super().log_config()
33
+ cluster_tree_summary = f"""
34
+ Reduction Dimension: {self.reduction_dimension}
35
+ Clustering Algorithm: {self.clustering_algorithm.__name__}
36
+ Clustering Parameters: {self.clustering_params}
37
+ """
38
+ return base_summary + cluster_tree_summary
39
+
40
+
41
+ class ClusterTreeBuilder(TreeBuilder):
42
+ def __init__(self, config) -> None:
43
+ super().__init__(config)
44
+
45
+ if not isinstance(config, ClusterTreeConfig):
46
+ raise ValueError("config must be an instance of ClusterTreeConfig")
47
+ self.reduction_dimension = config.reduction_dimension
48
+ self.clustering_algorithm = config.clustering_algorithm
49
+ self.clustering_params = config.clustering_params
50
+
51
+ logging.info(
52
+ f"Successfully initialized ClusterTreeBuilder with Config {config.log_config()}"
53
+ )
54
+
55
+ def construct_tree(
56
+ self,
57
+ current_level_nodes: Dict[int, Node],
58
+ all_tree_nodes: Dict[int, Node],
59
+ layer_to_nodes: Dict[int, List[Node]],
60
+ use_multithreading: bool = False,
61
+ ) -> Dict[int, Node]:
62
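+ # Layer-by-layer construction: cluster the nodes of the current layer, summarize
+ # each cluster into a new parent node, and repeat on the new layer until
+ # num_layers is reached or too few nodes remain to cluster further.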
+ logging.info("Using Cluster TreeBuilder")
63
+
64
+ next_node_index = len(all_tree_nodes)
65
+
66
+ def process_cluster(
67
+ cluster, new_level_nodes, next_node_index, summarization_length, lock
68
+ ):
69
+ node_texts = get_text(cluster)
70
+
71
+ summarized_text = self.summarize(
72
+ context=node_texts,
73
+ max_tokens=summarization_length,
74
+ )
75
+
76
+ logging.info(
77
+ f"Node Texts Length: {len(self.tokenizer.encode(node_texts))}, Summarized Text Length: {len(self.tokenizer.encode(summarized_text))}"
78
+ )
79
+
80
+ __, new_parent_node = self.create_node(
81
+ next_node_index, summarized_text, {node.index for node in cluster}
82
+ )
83
+
84
+ with lock:
85
+ new_level_nodes[next_node_index] = new_parent_node
86
+
87
+ for layer in range(self.num_layers):
88
+
89
+ new_level_nodes = {}
90
+
91
+ logging.info(f"Constructing Layer {layer}")
92
+
93
+ node_list_current_layer = get_node_list(current_level_nodes)
94
+
95
+ if len(node_list_current_layer) <= self.reduction_dimension + 1:
96
+ self.num_layers = layer
97
+ logging.info(
98
+ f"Stopping Layer construction: Cannot Create More Layers. Total Layers in tree: {layer}"
99
+ )
100
+ break
101
+
102
+ clusters = self.clustering_algorithm.perform_clustering(
103
+ node_list_current_layer,
104
+ self.cluster_embedding_model,
105
+ reduction_dimension=self.reduction_dimension,
106
+ **self.clustering_params,
107
+ )
108
+
109
+ lock = Lock()
110
+
111
+ summarization_length = self.summarization_length
112
+ logging.info(f"Summarization Length: {summarization_length}")
113
+
114
+ if use_multithreading:
115
+ with ThreadPoolExecutor() as executor:
116
+ for cluster in clusters:
117
+ executor.submit(
118
+ process_cluster,
119
+ cluster,
120
+ new_level_nodes,
121
+ next_node_index,
122
+ summarization_length,
123
+ lock,
124
+ )
125
+ next_node_index += 1
126
+ executor.shutdown(wait=True)
127
+
128
+ else:
129
+ for cluster in clusters:
130
+ process_cluster(
131
+ cluster,
132
+ new_level_nodes,
133
+ next_node_index,
134
+ summarization_length,
135
+ lock,
136
+ )
137
+ next_node_index += 1
138
+
139
+ layer_to_nodes[layer + 1] = list(new_level_nodes.values())
140
+ current_level_nodes = new_level_nodes
141
+ all_tree_nodes.update(new_level_nodes)
142
+
143
+ tree = Tree(
144
+ all_tree_nodes,
145
+ layer_to_nodes[layer + 1],
146
+ layer_to_nodes[0],
147
+ layer + 1,
148
+ layer_to_nodes,
149
+ )
150
+
151
+ return current_level_nodes
baselines/raptor/raptor/cluster_utils.py ADDED
@@ -0,0 +1,185 @@
 
 
1
+ import logging
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Optional
5
+
6
+ import numpy as np
7
+ import tiktoken
8
+ import umap
9
+ from sklearn.mixture import GaussianMixture
10
+
11
+ # Initialize logging
12
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
13
+
14
+ from .tree_structures import Node
15
+ # Import necessary methods from other modules
16
+ from .utils import get_embeddings
17
+
18
+ # Set a random seed for reproducibility
19
+ RANDOM_SEED = 224
20
+ random.seed(RANDOM_SEED)
21
+
22
+
23
+ def global_cluster_embeddings(
24
+ embeddings: np.ndarray,
25
+ dim: int,
26
+ n_neighbors: Optional[int] = None,
27
+ metric: str = "cosine",
28
+ ) -> np.ndarray:
29
+ if n_neighbors is None:
30
+ n_neighbors = int((len(embeddings) - 1) ** 0.5)
31
+ reduced_embeddings = umap.UMAP(
32
+ n_neighbors=n_neighbors, n_components=dim, metric=metric
33
+ ).fit_transform(embeddings)
34
+ return reduced_embeddings
35
+
36
+
37
+ def local_cluster_embeddings(
38
+ embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
39
+ ) -> np.ndarray:
40
+ reduced_embeddings = umap.UMAP(
41
+ n_neighbors=num_neighbors, n_components=dim, metric=metric
42
+ ).fit_transform(embeddings)
43
+ return reduced_embeddings
44
+
45
+
46
+ def get_optimal_clusters(
47
+ embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED
48
+ ) -> int:
49
+ max_clusters = min(max_clusters, len(embeddings))
50
+ n_clusters = np.arange(1, max_clusters)
51
+ bics = []
52
+ for n in n_clusters:
53
+ gm = GaussianMixture(n_components=n, random_state=random_state)
54
+ gm.fit(embeddings)
55
+ bics.append(gm.bic(embeddings))
56
+ optimal_clusters = n_clusters[np.argmin(bics)]
57
+ return optimal_clusters
58
+
59
+
60
+ def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
61
+ n_clusters = get_optimal_clusters(embeddings)
62
+ gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
63
+ gm.fit(embeddings)
64
+ probs = gm.predict_proba(embeddings)
65
+ labels = [np.where(prob > threshold)[0] for prob in probs]
66
+ return labels, n_clusters
67
+
68
+
69
+ def perform_clustering(
70
+ embeddings: np.ndarray, dim: int, threshold: float, verbose: bool = False
71
+ ) -> List[np.ndarray]:
72
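+ # Two-stage soft clustering, as implemented below:
+ #   1. Reduce all embeddings globally with UMAP, fit a GaussianMixture whose component
+ #      count minimizes BIC, and assign each point to every global cluster whose
+ #      posterior probability exceeds `threshold`.
+ #   2. Within each global cluster, repeat the UMAP + GMM step locally and record the
+ #      local cluster ids, offset by the running cluster total.
+ # Returns one label array per input embedding (soft assignment, possibly several labels).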
+ reduced_embeddings_global = global_cluster_embeddings(embeddings, min(dim, len(embeddings) -2))
73
+ global_clusters, n_global_clusters = GMM_cluster(
74
+ reduced_embeddings_global, threshold
75
+ )
76
+
77
+ if verbose:
78
+ logging.info(f"Global Clusters: {n_global_clusters}")
79
+
80
+ all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
81
+ total_clusters = 0
82
+
83
+ for i in range(n_global_clusters):
84
+ global_cluster_embeddings_ = embeddings[
85
+ np.array([i in gc for gc in global_clusters])
86
+ ]
87
+ if verbose:
88
+ logging.info(
89
+ f"Nodes in Global Cluster {i}: {len(global_cluster_embeddings_)}"
90
+ )
91
+ if len(global_cluster_embeddings_) == 0:
92
+ continue
93
+ if len(global_cluster_embeddings_) <= dim + 1:
94
+ local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
95
+ n_local_clusters = 1
96
+ else:
97
+ reduced_embeddings_local = local_cluster_embeddings(
98
+ global_cluster_embeddings_, dim
99
+ )
100
+ local_clusters, n_local_clusters = GMM_cluster(
101
+ reduced_embeddings_local, threshold
102
+ )
103
+
104
+ if verbose:
105
+ logging.info(f"Local Clusters in Global Cluster {i}: {n_local_clusters}")
106
+
107
+ for j in range(n_local_clusters):
108
+ local_cluster_embeddings_ = global_cluster_embeddings_[
109
+ np.array([j in lc for lc in local_clusters])
110
+ ]
111
+ indices = np.where(
112
+ (embeddings == local_cluster_embeddings_[:, None]).all(-1)
113
+ )[1]
114
+ for idx in indices:
115
+ all_local_clusters[idx] = np.append(
116
+ all_local_clusters[idx], j + total_clusters
117
+ )
118
+
119
+ total_clusters += n_local_clusters
120
+
121
+ if verbose:
122
+ logging.info(f"Total Clusters: {total_clusters}")
123
+ return all_local_clusters
124
+
125
+
126
+ class ClusteringAlgorithm(ABC):
127
+ @abstractmethod
128
+ def perform_clustering(self, embeddings: np.ndarray, **kwargs) -> List[List[int]]:
129
+ pass
130
+
131
+
132
+ class RAPTOR_Clustering(ClusteringAlgorithm):
133
+ def perform_clustering(
134
+ nodes: List[Node],
135
+ embedding_model_name: str,
136
+ max_length_in_cluster: int = 3500,
137
+ tokenizer=tiktoken.get_encoding("cl100k_base"),
138
+ reduction_dimension: int = 10,
139
+ threshold: float = 0.1,
140
+ verbose: bool = False,
141
+ ) -> List[List[Node]]:
142
+ # Get the embeddings from the nodes
143
+ embeddings = np.array([node.embeddings[embedding_model_name] for node in nodes])
144
+
145
+ # Perform the clustering
146
+ clusters = perform_clustering(
147
+ embeddings, dim=reduction_dimension, threshold=threshold
148
+ )
149
+
150
+ # Initialize an empty list to store the clusters of nodes
151
+ node_clusters = []
152
+
153
+ # Iterate over each unique label in the clusters
154
+ for label in np.unique(np.concatenate(clusters)):
155
+ # Get the indices of the nodes that belong to this cluster
156
+ indices = [i for i, cluster in enumerate(clusters) if label in cluster]
157
+
158
+ # Add the corresponding nodes to the node_clusters list
159
+ cluster_nodes = [nodes[i] for i in indices]
160
+
161
+ # Base case: if the cluster only has one node, do not attempt to recluster it
162
+ if len(cluster_nodes) == 1:
163
+ node_clusters.append(cluster_nodes)
164
+ continue
165
+
166
+ # Calculate the total length of the text in the nodes
167
+ total_length = sum(
168
+ [len(tokenizer.encode(node.text)) for node in cluster_nodes]
169
+ )
170
+
171
+ # If the total length exceeds the maximum allowed length, recluster this cluster
172
+ if total_length > max_length_in_cluster:
173
+ if verbose:
174
+ logging.info(
175
+ f"reclustering cluster with {len(cluster_nodes)} nodes"
176
+ )
177
+ node_clusters.extend(
178
+ RAPTOR_Clustering.perform_clustering(
179
+ cluster_nodes, embedding_model_name, max_length_in_cluster
180
+ )
181
+ )
182
+ else:
183
+ node_clusters.append(cluster_nodes)
184
+
185
+ return node_clusters
baselines/raptor/raptor/tree_builder.py ADDED
@@ -0,0 +1,369 @@
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from abc import abstractclassmethod
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+ from threading import Lock
7
+ from typing import Dict, List, Optional, Set, Tuple
8
+
9
+ import openai
10
+ import tiktoken
11
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
12
+
13
+ from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel
14
+ from .SummarizationModels import (BaseSummarizationModel,
15
+ GPT3TurboSummarizationModel)
16
+ from .tree_structures import Node, Tree
17
+ from .utils import (distances_from_embeddings, get_children, get_embeddings,
18
+ get_node_list, get_text,
19
+ indices_of_nearest_neighbors_from_distances, split_text)
20
+
21
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
22
+
23
+
24
+ class TreeBuilderConfig:
25
+ def __init__(
26
+ self,
27
+ tokenizer=None,
28
+ max_tokens=None,
29
+ num_layers=None,
30
+ threshold=None,
31
+ top_k=None,
32
+ selection_mode=None,
33
+ summarization_length=None,
34
+ summarization_model=None,
35
+ embedding_models=None,
36
+ cluster_embedding_model=None,
37
+ ):
38
+ if tokenizer is None:
39
+ tokenizer = tiktoken.get_encoding("cl100k_base")
40
+ self.tokenizer = tokenizer
41
+
42
+ if max_tokens is None:
43
+ max_tokens = 100
44
+ if not isinstance(max_tokens, int) or max_tokens < 1:
45
+ raise ValueError("max_tokens must be an integer and at least 1")
46
+ self.max_tokens = max_tokens
47
+
48
+ if num_layers is None:
49
+ num_layers = 5
50
+ if not isinstance(num_layers, int) or num_layers < 1:
51
+ raise ValueError("num_layers must be an integer and at least 1")
52
+ self.num_layers = num_layers
53
+
54
+ if threshold is None:
55
+ threshold = 0.5
56
+ if not isinstance(threshold, (int, float)) or not (0 <= threshold <= 1):
57
+ raise ValueError("threshold must be a number between 0 and 1")
58
+ self.threshold = threshold
59
+
60
+ if top_k is None:
61
+ top_k = 5
62
+ if not isinstance(top_k, int) or top_k < 1:
63
+ raise ValueError("top_k must be an integer and at least 1")
64
+ self.top_k = top_k
65
+
66
+ if selection_mode is None:
67
+ selection_mode = "top_k"
68
+ if selection_mode not in ["top_k", "threshold"]:
69
+ raise ValueError("selection_mode must be either 'top_k' or 'threshold'")
70
+ self.selection_mode = selection_mode
71
+
72
+ if summarization_length is None:
73
+ summarization_length = 100
74
+ self.summarization_length = summarization_length
75
+
76
+ if summarization_model is None:
77
+ summarization_model = GPT3TurboSummarizationModel()
78
+ if not isinstance(summarization_model, BaseSummarizationModel):
79
+ raise ValueError(
80
+ "summarization_model must be an instance of BaseSummarizationModel"
81
+ )
82
+ self.summarization_model = summarization_model
83
+
84
+ if embedding_models is None:
85
+ embedding_models = {"OpenAI": OpenAIEmbeddingModel()}
86
+ if not isinstance(embedding_models, dict):
87
+ raise ValueError(
88
+ "embedding_models must be a dictionary of model_name: instance pairs"
89
+ )
90
+ for model in embedding_models.values():
91
+ if not isinstance(model, BaseEmbeddingModel):
92
+ raise ValueError(
93
+ "All embedding models must be an instance of BaseEmbeddingModel"
94
+ )
95
+ self.embedding_models = embedding_models
96
+
97
+ if cluster_embedding_model is None:
98
+ cluster_embedding_model = "OpenAI"
99
+ if cluster_embedding_model not in self.embedding_models:
100
+ raise ValueError(
101
+ "cluster_embedding_model must be a key in the embedding_models dictionary"
102
+ )
103
+ self.cluster_embedding_model = cluster_embedding_model
104
+
105
+ def log_config(self):
106
+ config_log = """
107
+ TreeBuilderConfig:
108
+ Tokenizer: {tokenizer}
109
+ Max Tokens: {max_tokens}
110
+ Num Layers: {num_layers}
111
+ Threshold: {threshold}
112
+ Top K: {top_k}
113
+ Selection Mode: {selection_mode}
114
+ Summarization Length: {summarization_length}
115
+ Summarization Model: {summarization_model}
116
+ Embedding Models: {embedding_models}
117
+ Cluster Embedding Model: {cluster_embedding_model}
118
+ """.format(
119
+ tokenizer=self.tokenizer,
120
+ max_tokens=self.max_tokens,
121
+ num_layers=self.num_layers,
122
+ threshold=self.threshold,
123
+ top_k=self.top_k,
124
+ selection_mode=self.selection_mode,
125
+ summarization_length=self.summarization_length,
126
+ summarization_model=self.summarization_model,
127
+ embedding_models=self.embedding_models,
128
+ cluster_embedding_model=self.cluster_embedding_model,
129
+ )
130
+ return config_log
131
+
132
+
133
+ class TreeBuilder:
134
+ """
135
+ The TreeBuilder class is responsible for building a hierarchical text abstraction
136
+ structure, known as a "tree," using summarization models and
137
+ embedding models.
138
+ """
139
+
140
+ def __init__(self, config) -> None:
141
+ """Initializes the tokenizer, maximum tokens, number of layers, top-k value, threshold, and selection mode."""
142
+
143
+ self.tokenizer = config.tokenizer
144
+ self.max_tokens = config.max_tokens
145
+ self.num_layers = config.num_layers
146
+ self.top_k = config.top_k
147
+ self.threshold = config.threshold
148
+ self.selection_mode = config.selection_mode
149
+ self.summarization_length = config.summarization_length
150
+ self.summarization_model = config.summarization_model
151
+ self.embedding_models = config.embedding_models
152
+ self.cluster_embedding_model = config.cluster_embedding_model
153
+
154
+ logging.info(
155
+ f"Successfully initialized TreeBuilder with Config {config.log_config()}"
156
+ )
157
+
158
+ def create_node(
159
+ self, index: int, text: str, children_indices: Optional[Set[int]] = None
160
+ ) -> Tuple[int, Node]:
161
+ """Creates a new node with the given index, text, and (optionally) children indices.
162
+
163
+ Args:
164
+ index (int): The index of the new node.
165
+ text (str): The text associated with the new node.
166
+ children_indices (Optional[Set[int]]): A set of indices representing the children of the new node.
167
+ If not provided, an empty set will be used.
168
+
169
+ Returns:
170
+ Tuple[int, Node]: A tuple containing the index and the newly created node.
171
+ """
172
+ if children_indices is None:
173
+ children_indices = set()
174
+
175
+ embeddings = {
176
+ model_name: model.create_embedding(text)
177
+ for model_name, model in self.embedding_models.items()
178
+ }
179
+ return (index, Node(text, index, children_indices, embeddings))
180
+
181
+ def create_embedding(self, text) -> List[float]:
182
+ """
183
+ Generates embeddings for the given text using the specified embedding model.
184
+
185
+ Args:
186
+ text (str): The text for which to generate embeddings.
187
+
188
+ Returns:
189
+ List[float]: The generated embeddings.
190
+ """
191
+ return self.embedding_models[self.cluster_embedding_model].create_embedding(
192
+ text
193
+ )
194
+
195
+ def summarize(self, context, max_tokens=150) -> str:
196
+ """
197
+ Generates a summary of the input context using the specified summarization model.
198
+
199
+ Args:
200
+ context (str, optional): The context to summarize.
201
+ max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150.
202
+
203
+ Returns:
204
+ str: The generated summary.
205
+ """
206
+ return self.summarization_model.summarize(context, max_tokens)
207
+
208
+ def get_relevant_nodes(self, current_node, list_nodes) -> List[Node]:
209
+ """
210
+ Retrieves the top-k most relevant nodes to the current node from the list of nodes
211
+ based on cosine distance in the embedding space.
212
+
213
+ Args:
214
+ current_node (Node): The current node.
215
+ list_nodes (List[Node]): The list of nodes.
216
+
217
+ Returns:
218
+ List[Node]: The top-k most relevant nodes.
219
+ """
220
+ embeddings = get_embeddings(list_nodes, self.cluster_embedding_model)
221
+ distances = distances_from_embeddings(
222
+ current_node.embeddings[self.cluster_embedding_model], embeddings
223
+ )
224
+ indices = indices_of_nearest_neighbors_from_distances(distances)
225
+
226
+ if self.selection_mode == "threshold":
227
+ best_indices = [
228
+ index for index in indices if distances[index] > self.threshold
229
+ ]
230
+
231
+ elif self.selection_mode == "top_k":
232
+ best_indices = indices[: self.top_k]
233
+
234
+ nodes_to_add = [list_nodes[idx] for idx in best_indices]
235
+
236
+ return nodes_to_add
237
+
238
+ def multithreaded_create_leaf_nodes(self, chunks: List[str]) -> Dict[int, Node]:
239
+ """Creates leaf nodes using multithreading from the given list of text chunks.
240
+
241
+ Args:
242
+ chunks (List[str]): A list of text chunks to be turned into leaf nodes.
243
+
244
+ Returns:
245
+ Dict[int, Node]: A dictionary mapping node indices to the corresponding leaf nodes.
246
+ """
247
+ with ThreadPoolExecutor() as executor:
248
+ future_nodes = {
249
+ executor.submit(self.create_node, index, text): (index, text)
250
+ for index, text in enumerate(chunks)
251
+ }
252
+
253
+ leaf_nodes = {}
254
+ for future in as_completed(future_nodes):
255
+ index, node = future.result()
256
+ leaf_nodes[index] = node
257
+
258
+ return leaf_nodes
259
+
260
+ def build_from_text(self, text: str, use_multithreading: bool = True) -> Tree:
261
+ """Builds a golden tree from the input text, optionally using multithreading.
262
+
263
+ Args:
264
+ text (str): The input text.
265
+ use_multithreading (bool, optional): Whether to use multithreading when creating leaf nodes.
266
+ Default: True.
267
+
268
+ Returns:
269
+ Tree: The golden tree structure.
270
+ """
271
+ chunks = split_text(text, self.tokenizer, self.max_tokens)
272
+
273
+ logging.info("Creating Leaf Nodes")
274
+
275
+ if use_multithreading:
276
+ leaf_nodes = self.multithreaded_create_leaf_nodes(chunks)
277
+ else:
278
+ leaf_nodes = {}
279
+ for index, text in enumerate(chunks):
280
+ __, node = self.create_node(index, text)
281
+ leaf_nodes[index] = node
282
+
283
+ layer_to_nodes = {0: list(leaf_nodes.values())}
284
+
285
+ logging.info(f"Created {len(leaf_nodes)} Leaf Embeddings")
286
+
287
+ logging.info("Building All Nodes")
288
+
289
+ all_nodes = copy.deepcopy(leaf_nodes)
290
+
291
+ root_nodes = self.construct_tree(all_nodes, all_nodes, layer_to_nodes)
292
+
293
+ tree = Tree(all_nodes, root_nodes, leaf_nodes, self.num_layers, layer_to_nodes)
294
+
295
+ return tree
296
+
297
+ @abstractclassmethod
298
+ def construct_tree(
299
+ self,
300
+ current_level_nodes: Dict[int, Node],
301
+ all_tree_nodes: Dict[int, Node],
302
+ layer_to_nodes: Dict[int, List[Node]],
303
+ use_multithreading: bool = True,
304
+ ) -> Dict[int, Node]:
305
+ """
306
+ Constructs the hierarchical tree structure layer by layer by iteratively summarizing groups
307
+ of relevant nodes and updating the current_level_nodes and all_tree_nodes dictionaries at each step.
308
+
309
+ Args:
310
+ current_level_nodes (Dict[int, Node]): The current set of nodes.
311
+ all_tree_nodes (Dict[int, Node]): The dictionary of all nodes.
312
+ use_multithreading (bool): Whether to use multithreading to speed up the process.
313
+
314
+ Returns:
315
+ Dict[int, Node]: The final set of root nodes.
316
+ """
317
+ pass
318
+
319
+ # logging.info("Using Transformer-like TreeBuilder")
320
+
321
+ # def process_node(idx, current_level_nodes, new_level_nodes, all_tree_nodes, next_node_index, lock):
322
+ # relevant_nodes_chunk = self.get_relevant_nodes(
323
+ # current_level_nodes[idx], current_level_nodes
324
+ # )
325
+
326
+ # node_texts = get_text(relevant_nodes_chunk)
327
+
328
+ # summarized_text = self.summarize(
329
+ # context=node_texts,
330
+ # max_tokens=self.summarization_length,
331
+ # )
332
+
333
+ # logging.info(
334
+ # f"Node Texts Length: {len(self.tokenizer.encode(node_texts))}, Summarized Text Length: {len(self.tokenizer.encode(summarized_text))}"
335
+ # )
336
+
337
+ # next_node_index, new_parent_node = self.create_node(
338
+ # next_node_index,
339
+ # summarized_text,
340
+ # {node.index for node in relevant_nodes_chunk}
341
+ # )
342
+
343
+ # with lock:
344
+ # new_level_nodes[next_node_index] = new_parent_node
345
+
346
+ # for layer in range(self.num_layers):
347
+ # logging.info(f"Constructing Layer {layer}: ")
348
+
349
+ # node_list_current_layer = get_node_list(current_level_nodes)
350
+ # next_node_index = len(all_tree_nodes)
351
+
352
+ # new_level_nodes = {}
353
+ # lock = Lock()
354
+
355
+ # if use_multithreading:
356
+ # with ThreadPoolExecutor() as executor:
357
+ # for idx in range(0, len(node_list_current_layer)):
358
+ # executor.submit(process_node, idx, node_list_current_layer, new_level_nodes, all_tree_nodes, next_node_index, lock)
359
+ # next_node_index += 1
360
+ # executor.shutdown(wait=True)
361
+ # else:
362
+ # for idx in range(0, len(node_list_current_layer)):
363
+ # process_node(idx, node_list_current_layer, new_level_nodes, all_tree_nodes, next_node_index, lock)
364
+
365
+ # layer_to_nodes[layer + 1] = list(new_level_nodes.values())
366
+ # current_level_nodes = new_level_nodes
367
+ # all_tree_nodes.update(new_level_nodes)
368
+
369
+ # return new_level_nodes
baselines/raptor/raptor/tree_retriever.py ADDED
@@ -0,0 +1,327 @@
 
 
1
+ import logging
2
+ import os
3
+ from typing import Dict, List, Set
4
+
5
+ import tiktoken
6
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
7
+
8
+ from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel
9
+ from .Retrievers import BaseRetriever
10
+ from .tree_structures import Node, Tree
11
+ from .utils import (distances_from_embeddings, get_children, get_embeddings,
12
+ get_node_list, get_text,
13
+ indices_of_nearest_neighbors_from_distances,
14
+ reverse_mapping)
15
+
16
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
17
+
18
+
19
+ class TreeRetrieverConfig:
20
+ def __init__(
21
+ self,
22
+ tokenizer=None,
23
+ threshold=None,
24
+ top_k=None,
25
+ selection_mode=None,
26
+ context_embedding_model=None,
27
+ embedding_model=None,
28
+ num_layers=None,
29
+ start_layer=None,
30
+ ):
31
+ if tokenizer is None:
32
+ tokenizer = tiktoken.get_encoding("cl100k_base")
33
+ self.tokenizer = tokenizer
34
+
35
+ if threshold is None:
36
+ threshold = 0.5
37
+ if not isinstance(threshold, float) or not (0 <= threshold <= 1):
38
+ raise ValueError("threshold must be a float between 0 and 1")
39
+ self.threshold = threshold
40
+
41
+ if top_k is None:
42
+ top_k = 5
43
+ if not isinstance(top_k, int) or top_k < 1:
44
+ raise ValueError("top_k must be an integer and at least 1")
45
+ self.top_k = top_k
46
+
47
+ if selection_mode is None:
48
+ selection_mode = "top_k"
49
+ if not isinstance(selection_mode, str) or selection_mode not in [
50
+ "top_k",
51
+ "threshold",
52
+ ]:
53
+ raise ValueError(
54
+ "selection_mode must be a string and either 'top_k' or 'threshold'"
55
+ )
56
+ self.selection_mode = selection_mode
57
+
58
+ if context_embedding_model is None:
59
+ context_embedding_model = "OpenAI"
60
+ if not isinstance(context_embedding_model, str):
61
+ raise ValueError("context_embedding_model must be a string")
62
+ self.context_embedding_model = context_embedding_model
63
+
64
+ if embedding_model is None:
65
+ embedding_model = OpenAIEmbeddingModel()
66
+ if not isinstance(embedding_model, BaseEmbeddingModel):
67
+ raise ValueError(
68
+ "embedding_model must be an instance of BaseEmbeddingModel"
69
+ )
70
+ self.embedding_model = embedding_model
71
+
72
+ if num_layers is not None:
73
+ if not isinstance(num_layers, int) or num_layers < 0:
74
+ raise ValueError("num_layers must be an integer and at least 0")
75
+ self.num_layers = num_layers
76
+
77
+ if start_layer is not None:
78
+ if not isinstance(start_layer, int) or start_layer < 0:
79
+ raise ValueError("start_layer must be an integer and at least 0")
80
+ self.start_layer = start_layer
81
+
82
+ def log_config(self):
83
+ config_log = """
84
+ TreeRetrieverConfig:
85
+ Tokenizer: {tokenizer}
86
+ Threshold: {threshold}
87
+ Top K: {top_k}
88
+ Selection Mode: {selection_mode}
89
+ Context Embedding Model: {context_embedding_model}
90
+ Embedding Model: {embedding_model}
91
+ Num Layers: {num_layers}
92
+ Start Layer: {start_layer}
93
+ """.format(
94
+ tokenizer=self.tokenizer,
95
+ threshold=self.threshold,
96
+ top_k=self.top_k,
97
+ selection_mode=self.selection_mode,
98
+ context_embedding_model=self.context_embedding_model,
99
+ embedding_model=self.embedding_model,
100
+ num_layers=self.num_layers,
101
+ start_layer=self.start_layer,
102
+ )
103
+ return config_log
104
+
105
+
106
+ class TreeRetriever(BaseRetriever):
107
+
108
+ def __init__(self, config, tree) -> None:
109
+ if not isinstance(tree, Tree):
110
+ raise ValueError("tree must be an instance of Tree")
111
+
112
+ if config.num_layers is not None and config.num_layers > tree.num_layers + 1:
113
+ raise ValueError(
114
+ "num_layers in config must be less than or equal to tree.num_layers + 1"
115
+ )
116
+
117
+ if config.start_layer is not None and config.start_layer > tree.num_layers:
118
+ raise ValueError(
119
+ "start_layer in config must be less than or equal to tree.num_layers"
120
+ )
121
+
122
+ self.tree = tree
123
+ self.num_layers = (
124
+ config.num_layers if config.num_layers is not None else tree.num_layers + 1
125
+ )
126
+ self.start_layer = (
127
+ config.start_layer if config.start_layer is not None else tree.num_layers
128
+ )
129
+
130
+ if self.num_layers > self.start_layer + 1:
131
+ raise ValueError("num_layers must be less than or equal to start_layer + 1")
132
+
133
+ self.tokenizer = config.tokenizer
134
+ self.top_k = config.top_k
135
+ self.threshold = config.threshold
136
+ self.selection_mode = config.selection_mode
137
+ self.embedding_model = config.embedding_model
138
+ self.context_embedding_model = config.context_embedding_model
139
+
140
+ self.tree_node_index_to_layer = reverse_mapping(self.tree.layer_to_nodes)
141
+
142
+ logging.info(
143
+ f"Successfully initialized TreeRetriever with Config {config.log_config()}"
144
+ )
145
+
146
+ def create_embedding(self, text: str) -> List[float]:
147
+ """
148
+ Generates embeddings for the given text using the specified embedding model.
149
+
150
+ Args:
151
+ text (str): The text for which to generate embeddings.
152
+
153
+ Returns:
154
+ List[float]: The generated embeddings.
155
+ """
156
+ return self.embedding_model.create_embedding(text)
157
+
158
+ def retrieve_information_collapse_tree(self, query: str, top_k: int, max_tokens: int) -> str:
159
+ """
160
+ Retrieves the most relevant information from the tree based on the query.
161
+
162
+ Args:
163
+ query (str): The query text.
164
+ max_tokens (int): The maximum number of tokens.
165
+
166
+ Returns:
167
+ Tuple[List[Node], str]: The selected nodes and the context created from them.
168
+ """
169
+
170
+ query_embedding = self.create_embedding(query)
171
+
172
+ selected_nodes = []
173
+
174
+ node_list = get_node_list(self.tree.all_nodes)
175
+
176
+ embeddings = get_embeddings(node_list, self.context_embedding_model)
177
+
178
+ distances = distances_from_embeddings(query_embedding, embeddings)
179
+
180
+ indices = indices_of_nearest_neighbors_from_distances(distances)
181
+
182
+ total_tokens = 0
183
+ for idx in indices[:top_k]:
184
+
185
+ node = node_list[idx]
186
+ node_tokens = len(self.tokenizer.encode(node.text))
187
+
188
+ if total_tokens + node_tokens > max_tokens:
189
+ break
190
+
191
+ selected_nodes.append(node)
192
+ total_tokens += node_tokens
193
+
194
+ context = get_text(selected_nodes)
195
+ return selected_nodes, context
196
+
197
+ def retrieve_information(
198
+ self, current_nodes: List[Node], query: str, num_layers: int
199
+ ) -> str:
200
+ """
201
+ Retrieves the most relevant information from the tree based on the query.
202
+
203
+ Args:
204
+ current_nodes (List[Node]): A List of the current nodes.
205
+ query (str): The query text.
206
+ num_layers (int): The number of layers to traverse.
207
+
208
+ Returns:
209
+ Tuple[List[Node], str]: The selected nodes and the context created from them.
210
+ """
211
+
212
+ query_embedding = self.create_embedding(query)
213
+
214
+ selected_nodes = []
215
+
216
+ node_list = current_nodes
217
+
218
+ for layer in range(num_layers):
219
+
220
+ embeddings = get_embeddings(node_list, self.context_embedding_model)
221
+
222
+ distances = distances_from_embeddings(query_embedding, embeddings)
223
+
224
+ indices = indices_of_nearest_neighbors_from_distances(distances)
225
+
226
+ if self.selection_mode == "threshold":
227
+ best_indices = [
228
+ index for index in indices if distances[index] > self.threshold
229
+ ]
230
+
231
+ elif self.selection_mode == "top_k":
232
+ best_indices = indices[: self.top_k]
233
+
234
+ nodes_to_add = [node_list[idx] for idx in best_indices]
235
+
236
+ selected_nodes.extend(nodes_to_add)
237
+
238
+ if layer != num_layers - 1:
239
+
240
+ child_nodes = []
241
+
242
+ for index in best_indices:
243
+ child_nodes.extend(node_list[index].children)
244
+
245
+ # take the unique values
246
+ child_nodes = list(dict.fromkeys(child_nodes))
247
+ node_list = [self.tree.all_nodes[i] for i in child_nodes]
248
+
249
+ context = get_text(selected_nodes)
250
+ return selected_nodes, context
251
+
252
+ def retrieve(
253
+ self,
254
+ query: str,
255
+ start_layer: int = None,
256
+ num_layers: int = None,
257
+ top_k: int = 10,
258
+ max_tokens: int = 3500,
259
+ collapse_tree: bool = True,
260
+ return_layer_information: bool = False,
261
+ ) -> str:
262
+ """
263
+ Queries the tree and returns the most relevant information.
264
+
265
+ Args:
266
+ query (str): The query text.
267
+ start_layer (int): The layer to start from. Defaults to self.start_layer.
268
+ num_layers (int): The number of layers to traverse. Defaults to self.num_layers.
269
+ max_tokens (int): The maximum number of tokens. Defaults to 3500.
270
+ collapse_tree (bool): Whether to use collapsed-tree retrieval over all nodes. Defaults to True.
271
+
272
+ Returns:
273
+ str: The retrieved context; when return_layer_information is True, a (context, layer_information) tuple.
274
+ """
275
+
276
+ if not isinstance(query, str):
277
+ raise ValueError("query must be a string")
278
+
279
+ if not isinstance(max_tokens, int) or max_tokens < 1:
280
+ raise ValueError("max_tokens must be an integer and at least 1")
281
+
282
+ if not isinstance(collapse_tree, bool):
283
+ raise ValueError("collapse_tree must be a boolean")
284
+
285
+ # Set defaults
286
+ start_layer = self.start_layer if start_layer is None else start_layer
287
+ num_layers = self.num_layers if num_layers is None else num_layers
288
+
289
+ if not isinstance(start_layer, int) or not (
290
+ 0 <= start_layer <= self.tree.num_layers
291
+ ):
292
+ raise ValueError(
293
+ "start_layer must be an integer between 0 and tree.num_layers"
294
+ )
295
+
296
+ if not isinstance(num_layers, int) or num_layers < 1:
297
+ raise ValueError("num_layers must be an integer and at least 1")
298
+
299
+ if num_layers > (start_layer + 1):
300
+ raise ValueError("num_layers must be less than or equal to start_layer + 1")
301
+
302
+ if collapse_tree:
303
+ logging.info(f"Using collapsed_tree")
304
+ selected_nodes, context = self.retrieve_information_collapse_tree(
305
+ query, top_k, max_tokens
306
+ )
307
+ else:
308
+ layer_nodes = self.tree.layer_to_nodes[start_layer]
309
+ selected_nodes, context = self.retrieve_information(
310
+ layer_nodes, query, num_layers
311
+ )
312
+
313
+ if return_layer_information:
314
+
315
+ layer_information = []
316
+
317
+ for node in selected_nodes:
318
+ layer_information.append(
319
+ {
320
+ "node_index": node.index,
321
+ "layer_number": self.tree_node_index_to_layer[node.index],
322
+ }
323
+ )
324
+
325
+ return context, layer_information
326
+
327
+ return context
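Note: a minimal usage sketch of the retriever defined above, assuming a Tree has already been built and cached by the baseline script; the pickle path and the "EMB" embedding key are illustrative assumptions, not values fixed by this file.

import pickle
from raptor import SBertEmbeddingModel
from raptor.tree_retriever import TreeRetriever, TreeRetrieverConfig

# Load a previously built tree (placeholder path).
with open("baselines/raptor/trees/example_question.pkl", "rb") as f:
    tree = pickle.load(f)

config = TreeRetrieverConfig(
    embedding_model=SBertEmbeddingModel(),   # should match the model used at build time
    context_embedding_model="EMB",           # assumed key under Node.embeddings
    top_k=10,
    selection_mode="top_k",
)
retriever = TreeRetriever(config, tree)

# Collapsed-tree retrieval over all nodes, bounded by a token budget.
context = retriever.retrieve(
    "What did the user plan for the spring trip?",
    top_k=10, max_tokens=3500, collapse_tree=True,
)
print(context[:500])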
baselines/raptor/raptor/tree_structures.py ADDED
@@ -0,0 +1,28 @@
1
+ from typing import Dict, List, Set
2
+
3
+
4
+ class Node:
5
+ """
6
+ Represents a node in the hierarchical tree structure.
7
+ """
8
+
9
+ def __init__(self, text: str, index: int, children: Set[int], embeddings) -> None:
10
+ self.text = text
11
+ self.index = index
12
+ self.children = children
13
+ self.embeddings = embeddings
14
+
15
+
16
+ class Tree:
17
+ """
18
+ Represents the entire hierarchical tree structure.
19
+ """
20
+
21
+ def __init__(
22
+ self, all_nodes, root_nodes, leaf_nodes, num_layers, layer_to_nodes
23
+ ) -> None:
24
+ self.all_nodes = all_nodes
25
+ self.root_nodes = root_nodes
26
+ self.leaf_nodes = leaf_nodes
27
+ self.num_layers = num_layers
28
+ self.layer_to_nodes = layer_to_nodes
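A toy instantiation of these two structures, for illustration only (the "EMB" key, the texts, and the embedding values are made up; real trees are produced by the tree builder):

from raptor.tree_structures import Node, Tree

# Two leaves and one root that summarizes them; `children` holds child node indices.
leaf_a = Node("Session A summary ...", index=0, children=set(), embeddings={"EMB": [0.1, 0.2]})
leaf_b = Node("Session B summary ...", index=1, children=set(), embeddings={"EMB": [0.3, 0.4]})
root = Node("Summary of sessions A and B", index=2, children={0, 1}, embeddings={"EMB": [0.2, 0.3]})

tree = Tree(
    all_nodes={0: leaf_a, 1: leaf_b, 2: root},
    root_nodes={2: root},
    leaf_nodes={0: leaf_a, 1: leaf_b},
    num_layers=1,
    layer_to_nodes={0: [leaf_a, leaf_b], 1: [root]},
)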
baselines/raptor/raptor/utils.py ADDED
@@ -0,0 +1,208 @@
1
+ import logging
2
+ import re
3
+ from typing import Dict, List, Set
4
+
5
+ import numpy as np
6
+ import tiktoken
7
+ from scipy import spatial
8
+
9
+ from .tree_structures import Node
10
+
11
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
12
+
13
+
14
+ def reverse_mapping(layer_to_nodes: Dict[int, List[Node]]) -> Dict[int, int]:
15
+ node_to_layer = {}
16
+ for layer, nodes in layer_to_nodes.items():
17
+ for node in nodes:
18
+ node_to_layer[node.index] = layer
19
+ return node_to_layer
20
+
21
+
22
+ def split_text(
23
+ text: str, tokenizer: tiktoken.get_encoding("cl100k_base"), max_tokens: int, overlap: int = 0
24
+ ):
25
+ """
26
+ Splits the input text into smaller chunks based on the tokenizer and maximum allowed tokens.
27
+
28
+ Args:
29
+ text (str): The text to be split.
30
+ tokenizer (CustomTokenizer): The tokenizer to be used for splitting the text.
31
+ max_tokens (int): The maximum allowed tokens.
32
+ overlap (int, optional): The number of overlapping tokens between chunks. Defaults to 0.
33
+
34
+ Returns:
35
+ List[str]: A list of text chunks.
36
+ """
37
+ # Split the text into sentences using multiple delimiters
38
+ delimiters = [".", "!", "?", "\n"]
39
+ regex_pattern = "|".join(map(re.escape, delimiters))
40
+ sentences = re.split(regex_pattern, text)
41
+
42
+ # Calculate the number of tokens for each sentence
43
+ n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]
44
+
45
+ chunks = []
46
+ current_chunk = []
47
+ current_length = 0
48
+
49
+ for sentence, token_count in zip(sentences, n_tokens):
50
+ # If the sentence is empty or consists only of whitespace, skip it
51
+ if not sentence.strip():
52
+ continue
53
+
54
+ # If the sentence is too long, split it into smaller parts
55
+ if token_count > max_tokens:
56
+ sub_sentences = re.split(r"[,;:]", sentence)
57
+
58
+ # there is no need to keep empty or only-spaced strings
59
+ # since spaces will be inserted at the beginning of the full string
60
+ # and in between the strings in the sub_chunk list
61
+ filtered_sub_sentences = [sub.strip() for sub in sub_sentences if sub.strip() != ""]
62
+ sub_token_counts = [len(tokenizer.encode(" " + sub_sentence)) for sub_sentence in filtered_sub_sentences]
63
+
64
+ sub_chunk = []
65
+ sub_length = 0
66
+
67
+ for sub_sentence, sub_token_count in zip(filtered_sub_sentences, sub_token_counts):
68
+ if sub_length + sub_token_count > max_tokens:
69
+
70
+ # if the phrase does not have sub_sentences, it would create an empty chunk
71
+ # this big phrase would be added anyways in the next chunk append
72
+ if sub_chunk:
73
+ chunks.append(" ".join(sub_chunk))
74
+ sub_chunk = sub_chunk[-overlap:] if overlap > 0 else []
75
+ sub_length = sum(sub_token_counts[max(0, len(sub_chunk) - overlap):len(sub_chunk)])
76
+
77
+ sub_chunk.append(sub_sentence)
78
+ sub_length += sub_token_count
79
+
80
+ if sub_chunk:
81
+ chunks.append(" ".join(sub_chunk))
82
+
83
+ # If adding the sentence to the current chunk exceeds the max tokens, start a new chunk
84
+ elif current_length + token_count > max_tokens:
85
+ chunks.append(" ".join(current_chunk))
86
+ current_chunk = current_chunk[-overlap:] if overlap > 0 else []
87
+ current_length = sum(n_tokens[max(0, len(current_chunk) - overlap):len(current_chunk)])
88
+ current_chunk.append(sentence)
89
+ current_length += token_count
90
+
91
+ # Otherwise, add the sentence to the current chunk
92
+ else:
93
+ current_chunk.append(sentence)
94
+ current_length += token_count
95
+
96
+ # Add the last chunk if it's not empty
97
+ if current_chunk:
98
+ chunks.append(" ".join(current_chunk))
99
+
100
+ return chunks
101
+
102
+
103
+ def distances_from_embeddings(
104
+ query_embedding: List[float],
105
+ embeddings: List[List[float]],
106
+ distance_metric: str = "cosine",
107
+ ) -> List[float]:
108
+ """
109
+ Calculates the distances between a query embedding and a list of embeddings.
110
+
111
+ Args:
112
+ query_embedding (List[float]): The query embedding.
113
+ embeddings (List[List[float]]): A list of embeddings to compare against the query embedding.
114
+ distance_metric (str, optional): The distance metric to use for calculation. Defaults to 'cosine'.
115
+
116
+ Returns:
117
+ List[float]: The calculated distances between the query embedding and the list of embeddings.
118
+ """
119
+ distance_metrics = {
120
+ "cosine": spatial.distance.cosine,
121
+ "L1": spatial.distance.cityblock,
122
+ "L2": spatial.distance.euclidean,
123
+ "Linf": spatial.distance.chebyshev,
124
+ }
125
+
126
+ if distance_metric not in distance_metrics:
127
+ raise ValueError(
128
+ f"Unsupported distance metric '{distance_metric}'. Supported metrics are: {list(distance_metrics.keys())}"
129
+ )
130
+
131
+ distances = [
132
+ distance_metrics[distance_metric](query_embedding, embedding)
133
+ for embedding in embeddings
134
+ ]
135
+
136
+ return distances
137
+
138
+
139
+ def get_node_list(node_dict: Dict[int, Node]) -> List[Node]:
140
+ """
141
+ Converts a dictionary of node indices to a sorted list of nodes.
142
+
143
+ Args:
144
+ node_dict (Dict[int, Node]): Dictionary of node indices to nodes.
145
+
146
+ Returns:
147
+ List[Node]: Sorted list of nodes.
148
+ """
149
+ indices = sorted(node_dict.keys())
150
+ node_list = [node_dict[index] for index in indices]
151
+ return node_list
152
+
153
+
154
+ def get_embeddings(node_list: List[Node], embedding_model: str) -> List:
155
+ """
156
+ Extracts the embeddings of nodes from a list of nodes.
157
+
158
+ Args:
159
+ node_list (List[Node]): List of nodes.
160
+ embedding_model (str): The name of the embedding model to be used.
161
+
162
+ Returns:
163
+ List: List of node embeddings.
164
+ """
165
+ return [node.embeddings[embedding_model] for node in node_list]
166
+
167
+
168
+ def get_children(node_list: List[Node]) -> List[Set[int]]:
169
+ """
170
+ Extracts the children of nodes from a list of nodes.
171
+
172
+ Args:
173
+ node_list (List[Node]): List of nodes.
174
+
175
+ Returns:
176
+ List[Set[int]]: List of sets of node children indices.
177
+ """
178
+ return [node.children for node in node_list]
179
+
180
+
181
+ def get_text(node_list: List[Node]) -> str:
182
+ """
183
+ Generates a single text string by concatenating the text from a list of nodes.
184
+
185
+ Args:
186
+ node_list (List[Node]): List of nodes.
187
+
188
+ Returns:
189
+ str: Concatenated text.
190
+ """
191
+ text = ""
192
+ for node in node_list:
193
+ text += f"{' '.join(node.text.splitlines())}"
194
+ text += "\n\n"
195
+ return text
196
+
197
+
198
+ def indices_of_nearest_neighbors_from_distances(distances: List[float]) -> np.ndarray:
199
+ """
200
+ Returns the indices of nearest neighbors sorted in ascending order of distance.
201
+
202
+ Args:
203
+ distances (List[float]): A list of distances between embeddings.
204
+
205
+ Returns:
206
+ np.ndarray: An array of indices sorted by ascending distance.
207
+ """
208
+ return np.argsort(distances)
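A quick sketch of how the chunking and nearest-neighbour helpers above fit together; the two-dimensional embeddings are made up for illustration (real ones come from the configured embedding model):

import tiktoken
from raptor.utils import (split_text, distances_from_embeddings,
                          indices_of_nearest_neighbors_from_distances)

tok = tiktoken.get_encoding("cl100k_base")
chunks = split_text("First sentence. Second sentence! A third, longer one?", tok, max_tokens=10)
print(chunks)  # sentence-level chunks grouped up to ~10 tokens each (exact split depends on token counts)

query_emb = [0.9, 0.1]
candidate_embs = [[0.1, 0.9], [0.8, 0.2]]
dists = distances_from_embeddings(query_emb, candidate_embs)   # cosine distance by default
order = indices_of_nearest_neighbors_from_distances(dists)     # indices sorted by ascending distance
print(order)  # array([1, 0]) -- the second candidate is the nearer one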
baselines/raptor/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ faiss-cpu
2
+ numpy==1.26.3
3
+ openai==1.3.3
4
+ scikit-learn
5
+ sentence-transformers==2.2.2
6
+ tenacity==8.2.3
7
+ tiktoken==0.5.1
8
+ torch
9
+ transformers==4.38.1
10
+ umap-learn==0.5.5
11
+ urllib3==1.26.6
baselines/raptor/run_raptor_baseline.py ADDED
@@ -0,0 +1,511 @@
1
+ """
2
+ RAPTOR baseline for the EvolV-Mem benchmark.
3
+
4
+ Builds a RAPTOR tree per question from session summaries, retrieves context
5
+ via collapsed-tree retrieval, and generates answers using Qwen-30B via vLLM.
6
+
7
+ Usage:
8
+ python baselines/raptor/run_raptor_baseline.py \
9
+ --in_file dataset/evolv_mem_v4.json \
10
+ --out_file output/raptor_qwen30b.jsonl \
11
+ --summary_file dataset/all_session_summary.json \
12
+ --profile_file metadata/generated_user_profile.json
13
+
14
+ Env vars:
15
+ VLLM_BASE_URL (default http://localhost:8000/v1)
16
+ VLLM_API_KEY (default EMPTY)
17
+ """
18
+
19
+ import argparse
20
+ import copy
21
+ import json
22
+ import logging
23
+ import os
24
+ import pickle
25
+ import re
26
+ import sys
27
+ import time
28
+ from abc import ABC, abstractmethod
29
+ from collections import defaultdict
30
+ from typing import Dict, List, Optional
31
+
32
+ from tqdm import tqdm
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Add the raptor package to the path so we can import it directly
36
+ # ---------------------------------------------------------------------------
37
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
38
+ sys.path.insert(0, SCRIPT_DIR)
39
+
40
+ from raptor import (
41
+ BaseEmbeddingModel,
42
+ BaseQAModel,
43
+ BaseSummarizationModel,
44
+ RetrievalAugmentation,
45
+ RetrievalAugmentationConfig,
46
+ SBertEmbeddingModel,
47
+ TreeRetriever,
48
+ )
49
+
50
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # vLLM-backed Summarization Model
54
+ # ---------------------------------------------------------------------------
55
+
56
+ class VLLMSummarizationModel(BaseSummarizationModel):
57
+ """Summarization model backed by a vLLM OpenAI-compatible server."""
58
+
59
+ def __init__(
60
+ self,
61
+ model_name: str = None,
62
+ base_url: str = None,
63
+ api_key: str = None,
64
+ ):
65
+ from openai import OpenAI
66
+
67
+ self.model_name = (
68
+ model_name
69
+ or os.getenv("VLLM_MODEL_NAME")
70
+ or "Qwen/Qwen3-30B-A3B-Instruct-2507"
71
+ )
72
+ self.client = OpenAI(
73
+ base_url=base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
74
+ api_key=api_key or os.getenv("VLLM_API_KEY", "EMPTY"),
75
+ )
76
+
77
+ def summarize(self, context, max_tokens=2048):
78
+ for attempt in range(6):
79
+ try:
80
+ response = self.client.chat.completions.create(
81
+ model=self.model_name,
82
+ messages=[
83
+ {"role": "system", "content": "You are a helpful assistant."},
84
+ {
85
+ "role": "user",
86
+ "content": (
87
+ "Write a summary of the following, including as many "
88
+ "key details as possible:\n\n"
89
+ f"{context}"
90
+ ),
91
+ },
92
+ ],
93
+ max_tokens=max_tokens,
94
+ temperature=0.3,
95
+ )
96
+ content = response.choices[0].message.content if response.choices else None
97
+ if content is None:
98
+ wait = min(2 ** attempt * 2, 30)
99
+ print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
100
+ time.sleep(wait)
101
+ continue
102
+ return content.strip()
103
+ except Exception as e:
104
+ msg = str(e).lower()
105
+ if any(code in msg for code in ("429", "500", "503", "rate limit")):
106
+ wait = min(2 ** attempt * 5, 60)
107
+ print(f"[WARN] Summarization retry {attempt+1}/6, sleeping {wait}s: {e}")
108
+ time.sleep(wait)
109
+ continue
110
+ print(f"[ERROR] Summarization failed: {e}")
111
+ raise
112
+ raise RuntimeError("Summarization failed after 6 retries")
113
+
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # vLLM-backed QA Model
117
+ # ---------------------------------------------------------------------------
118
+
119
+ class VLLMQAModel(BaseQAModel):
120
+ """QA model backed by a vLLM OpenAI-compatible server.
121
+
122
+ Mutable attributes `question_date` and `user_profile` should be set
123
+ before each call to include per-question context in the prompt.
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ model_name: str = None,
129
+ base_url: str = None,
130
+ api_key: str = None,
131
+ ):
132
+ from openai import OpenAI
133
+
134
+ self.model_name = (
135
+ model_name
136
+ or os.getenv("VLLM_MODEL_NAME")
137
+ or "Qwen/Qwen3-30B-A3B-Instruct-2507"
138
+ )
139
+ self.client = OpenAI(
140
+ base_url=base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
141
+ api_key=api_key or os.getenv("VLLM_API_KEY", "EMPTY"),
142
+ )
143
+ # Set these per-question before calling answer_question
144
+ self.question_date: Optional[str] = None
145
+ self.user_profile: Optional[str] = None
146
+
147
+ def answer_question(self, context, question):
148
+ # Build prompt matching the project's answer template (main.py:1490)
149
+ parts = []
150
+ parts.append(
151
+ "I will give you several chat history sessions between you and a user. "
152
+ "Please answer the question given the information."
153
+ )
154
+ if self.user_profile:
155
+ parts.append(f"\n\nUser Profile:\n{self.user_profile}")
156
+ parts.append(f"\n\nChat history sessions:\n\n{context}")
157
+ if self.question_date:
158
+ parts.append(f"\n\nCurrent Date: {self.question_date}")
159
+ parts.append(f"\nQuestion: {question}\nAnswer:")
160
+
161
+ prompt = "".join(parts)
162
+
163
+ for attempt in range(6):
164
+ try:
165
+ response = self.client.chat.completions.create(
166
+ model=self.model_name,
167
+ messages=[{"role": "user", "content": prompt}],
168
+ max_tokens=8192,
169
+ temperature=0.3,
170
+ )
171
+ content = response.choices[0].message.content if response.choices else None
172
+ if content is None:
173
+ wait = min(2 ** attempt * 2, 30)
174
+ print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
175
+ time.sleep(wait)
176
+ continue
177
+ return content.strip()
178
+ except Exception as e:
179
+ msg = str(e).lower()
180
+ if any(code in msg for code in ("429", "500", "503", "rate limit")):
181
+ wait = min(2 ** attempt * 5, 60)
182
+ print(f"[WARN] QA retry {attempt+1}/6, sleeping {wait}s: {e}")
183
+ time.sleep(wait)
184
+ continue
185
+ print(f"[ERROR] QA failed: {e}")
186
+ raise
187
+ raise RuntimeError("QA failed after 6 retries")
188
+
189
+
190
+ # ---------------------------------------------------------------------------
191
+ # Data helpers
192
+ # ---------------------------------------------------------------------------
193
+
194
+ def prepare_session_documents(
195
+ haystack_session_ids: List[str],
196
+ haystack_dates: List[str],
197
+ summaries: Dict,
198
+ ) -> List[str]:
199
+ """Format session summaries as RAPTOR leaf documents.
200
+
201
+ Each document is a short block:
202
+ [Session {sid} | Date: {date}]
203
+ {session_summary_text}
204
+ """
205
+ docs = []
206
+ for sid, date_str in zip(haystack_session_ids, haystack_dates):
207
+ summary_data = summaries.get(sid)
208
+ if summary_data is None:
209
+ continue
210
+ text = summary_data.get("session_summary", "")
211
+ if not text:
212
+ # Fallback: join turn summaries
213
+ turn_sums = summary_data.get("turn_summaries", [])
214
+ if turn_sums:
215
+ text = " ".join(turn_sums)
216
+ else:
217
+ continue
218
+ doc = f"[Session {sid} | Date: {date_str}]\n{text}"
219
+ docs.append(doc)
220
+ return docs
221
+
222
+
223
+ # ---------------------------------------------------------------------------
224
+ # Tree building / caching
225
+ # ---------------------------------------------------------------------------
226
+
227
+ def build_or_load_tree(
228
+ question_id: str,
229
+ docs: List[str],
230
+ tree_builder,
231
+ tree_cache_dir: str,
232
+ ):
233
+ """Build a RAPTOR tree from docs, or load from cache."""
234
+ tree_path = os.path.join(tree_cache_dir, f"{question_id}.pkl")
235
+
236
+ if os.path.exists(tree_path):
237
+ logging.info(f"Loading cached tree for {question_id}")
238
+ with open(tree_path, "rb") as f:
239
+ tree = pickle.load(f)
240
+ return tree
241
+
242
+ # Join docs with double-newline separator so RAPTOR's split_text keeps
243
+ # each ~230-token summary as a single leaf node (with tb_max_tokens=300).
244
+ text = "\n\n".join(docs)
245
+
246
+ logging.info(f"Building tree for {question_id} ({len(docs)} docs)")
247
+ tree = tree_builder.build_from_text(text, use_multithreading=True)
248
+
249
+ os.makedirs(tree_cache_dir, exist_ok=True)
250
+ with open(tree_path, "wb") as f:
251
+ pickle.dump(tree, f)
252
+ logging.info(f"Saved tree for {question_id} -> {tree_path}")
253
+
254
+ return tree
255
+
256
+
257
+ # ---------------------------------------------------------------------------
258
+ # Retrieval metrics
259
+ # ---------------------------------------------------------------------------
260
+
261
+ _SESSION_ID_RE = re.compile(r"\[Session\s+(\S+)\s*\|")
262
+
263
+
264
+ def extract_session_ids_from_context(context: str) -> List[str]:
265
+ """Parse session IDs from RAPTOR-retrieved context text.
266
+
267
+ Leaf nodes are formatted as '[Session {sid} | Date: ...]\\n{summary}'.
268
+ Higher-level nodes are summaries of clusters and won't contain session IDs.
269
+ """
270
+ return list(dict.fromkeys(_SESSION_ID_RE.findall(context))) # unique, order-preserving
271
+
272
+
273
+ def evaluate_retrieval(recalled_docs, correct_docs):
274
+ recall_any = float(any(doc in recalled_docs for doc in correct_docs))
275
+ recall_all = float(all(doc in recalled_docs for doc in correct_docs))
276
+ return recall_any, recall_all
277
+
278
+
279
+ def print_average_metrics(retrieval_metric_list):
280
+ metric_sums = defaultdict(float)
281
+ metric_counts = defaultdict(int)
282
+ for metric in retrieval_metric_list:
283
+ for k, v in metric.items():
284
+ metric_sums[k] += v
285
+ metric_counts[k] += 1
286
+ print(" Average retrieval metrics:")
287
+ for k in sorted(metric_sums):
288
+ avg = metric_sums[k] / metric_counts[k]
289
+ print(f" {k}: {avg:.4f}")
290
+
291
+
292
+ # ---------------------------------------------------------------------------
293
+ # Main
294
+ # ---------------------------------------------------------------------------
295
+
296
+ def main():
297
+ parser = argparse.ArgumentParser(description="RAPTOR baseline for EvolV-Mem")
298
+ parser.add_argument("--in_file", type=str, required=True,
299
+ help="Path to evolv_mem_v4.json")
300
+ parser.add_argument("--out_file", type=str, required=True,
301
+ help="Output JSONL file")
302
+ parser.add_argument("--summary_file", type=str, required=True,
303
+ help="Path to all_session_summary.json")
304
+ parser.add_argument("--profile_file", type=str, default=None,
305
+ help="Path to generated_user_profile.json")
306
+ parser.add_argument("--tree_cache_dir", type=str,
307
+ default="baselines/raptor/trees",
308
+ help="Directory to cache built trees")
309
+ # RAPTOR tree builder params
310
+ parser.add_argument("--tb_max_tokens", type=int, default=300,
311
+ help="Max tokens per leaf chunk (default 300)")
312
+ parser.add_argument("--tb_num_layers", type=int, default=3,
313
+ help="Number of tree layers (default 3)")
314
+ parser.add_argument("--tb_summarization_length", type=int, default=200,
315
+ help="Max tokens per cluster summary (default 200)")
316
+ # RAPTOR retrieval params
317
+ parser.add_argument("--tr_top_k", type=int, default=10,
318
+ help="Top-k nodes to retrieve (default 10)")
319
+ parser.add_argument("--max_retrieval_tokens", type=int, default=8000,
320
+ help="Token budget for retrieved context (default 8000)")
321
+ # Embedding model
322
+ parser.add_argument("--embedding_model", type=str,
323
+ default="sentence-transformers/multi-qa-mpnet-base-cos-v1",
324
+ help="SentenceTransformer model for embeddings")
325
+ # Index range (for parallel jobs)
326
+ parser.add_argument("--start_idx", type=int, default=None,
327
+ help="Start index (inclusive) for question subset")
328
+ parser.add_argument("--end_idx", type=int, default=None,
329
+ help="End index (exclusive) for question subset")
330
+ # Limit (for debugging)
331
+ parser.add_argument("--limit", type=int, default=None,
332
+ help="Process only the first N questions")
333
+ args = parser.parse_args()
334
+
335
+ # -----------------------------------------------------------------------
336
+ # Load data
337
+ # -----------------------------------------------------------------------
338
+ print(f"Loading benchmark from {args.in_file} ...")
339
+ with open(args.in_file) as f:
340
+ benchmark = json.load(f)
341
+ if args.start_idx is not None or args.end_idx is not None:
342
+ s = args.start_idx or 0
343
+ e = args.end_idx or len(benchmark)
344
+ benchmark = benchmark[s:e]
345
+ print(f" Using index range [{s}, {e})")
346
+ if args.limit:
347
+ benchmark = benchmark[: args.limit]
348
+ print(f" {len(benchmark)} questions loaded.")
349
+
350
+ print(f"Loading session summaries from {args.summary_file} ...")
351
+ with open(args.summary_file) as f:
352
+ summaries = json.load(f)
353
+ print(f" {len(summaries)} sessions loaded.")
354
+
355
+ profiles = {}
356
+ if args.profile_file and os.path.exists(args.profile_file):
357
+ print(f"Loading user profiles from {args.profile_file} ...")
358
+ with open(args.profile_file) as f:
359
+ profiles = json.load(f)
360
+ print(f" {len(profiles)} profiles loaded.")
361
+
362
+ # -----------------------------------------------------------------------
363
+ # Resume support: load existing output
364
+ # -----------------------------------------------------------------------
365
+ existing_qids = set()
366
+ if os.path.exists(args.out_file):
367
+ with open(args.out_file) as f:
368
+ for line in f:
369
+ line = line.strip()
370
+ if line:
371
+ obj = json.loads(line)
372
+ existing_qids.add(obj["question_id"])
373
+ print(f" Resuming: {len(existing_qids)} questions already processed.")
374
+
375
+ # -----------------------------------------------------------------------
376
+ # Initialize models
377
+ # -----------------------------------------------------------------------
378
+ print("Initializing models ...")
379
+ embedding_model = SBertEmbeddingModel(model_name=args.embedding_model)
380
+ summarization_model = VLLMSummarizationModel()
381
+ qa_model = VLLMQAModel()
382
+
383
+ # -----------------------------------------------------------------------
384
+ # Build RAPTOR config
385
+ # -----------------------------------------------------------------------
386
+ config = RetrievalAugmentationConfig(
387
+ summarization_model=summarization_model,
388
+ qa_model=qa_model,
389
+ embedding_model=embedding_model,
390
+ tb_max_tokens=args.tb_max_tokens,
391
+ tb_num_layers=args.tb_num_layers,
392
+ tb_summarization_length=args.tb_summarization_length,
393
+ tr_top_k=args.tr_top_k,
394
+ )
395
+
396
+ # Pre-create tree builder (reused across questions)
397
+ tree_builder = config.tree_builder_config
398
+ # We need the actual builder instance from a fresh RA to reuse
399
+ ra_template = RetrievalAugmentation(config=config)
400
+ tree_builder_instance = ra_template.tree_builder
401
+
402
+ os.makedirs(args.tree_cache_dir, exist_ok=True)
403
+
404
+ # -----------------------------------------------------------------------
405
+ # Process questions
406
+ # -----------------------------------------------------------------------
407
+ retrieval_metric_list = []
408
+ out_f = open(args.out_file, "a")
409
+
410
+ for di, entry in enumerate(tqdm(benchmark, desc="RAPTOR baseline")):
411
+ qid = entry["question_id"]
412
+ question = entry["question"]
413
+ question_date = entry["question_date"]
414
+
415
+ if qid in existing_qids:
416
+ continue
417
+
418
+ try:
419
+ # 1. Prepare documents from session summaries
420
+ docs = prepare_session_documents(
421
+ entry["haystack_session_ids"],
422
+ entry["haystack_dates"],
423
+ summaries,
424
+ )
425
+
426
+ if not docs:
427
+ print(f"[WARN] q_idx={di} qid={qid}: no session summaries found, skipping.")
428
+ result = {
429
+ "q_idx": di,
430
+ "question_id": qid,
431
+ "hypothesis": "Insufficient information to answer.",
432
+ "n_docs": 0,
433
+ }
434
+ print(json.dumps(result), file=out_f, flush=True)
435
+ continue
436
+
437
+ # 2. Build or load RAPTOR tree
438
+ tree = build_or_load_tree(
439
+ qid, docs, tree_builder_instance, args.tree_cache_dir
440
+ )
441
+
442
+ # 3. Create RA instance with this tree
443
+ ra = RetrievalAugmentation(config=config, tree=tree)
444
+
445
+ # 4. Set per-question context on the QA model
446
+ qa_model.question_date = question_date
447
+ user_id = qid.split("_q_")[0] if "_q_" in qid else qid
448
+ qa_model.user_profile = profiles.get(user_id, None)
449
+
450
+ # 5. Retrieve context (separate from QA so we can extract session IDs)
451
+ context, layer_info = ra.retrieve(
452
+ question=question,
453
+ top_k=args.tr_top_k,
454
+ max_tokens=args.max_retrieval_tokens,
455
+ collapse_tree=True,
456
+ return_layer_information=True,
457
+ )
458
+
459
+ # 5a. Extract retrieved session IDs from context text
460
+ retrieved_session_ids = extract_session_ids_from_context(context)
461
+
462
+ # 5b. Generate answer manually
463
+ answer = qa_model.answer_question(context, question)
464
+
465
+ # 5c. Compute retrieval metrics
466
+ answer_session_ids = entry.get("answer_session_ids", [])
467
+ retrieval_metric = {}
468
+ if answer_session_ids and retrieved_session_ids:
469
+ for topk in [5, 10, 20, 30]:
470
+ r_any, r_all = evaluate_retrieval(
471
+ retrieved_session_ids[:topk], answer_session_ids
472
+ )
473
+ retrieval_metric[f"recall_any@{topk}"] = r_any
474
+ retrieval_metric[f"recall_all@{topk}"] = r_all
475
+ retrieval_metric_list.append(retrieval_metric)
476
+ print_average_metrics(retrieval_metric_list)
477
+
478
+ # 6. Write output
479
+ result = {
480
+ "q_idx": di,
481
+ "question_id": qid,
482
+ "hypothesis": answer,
483
+ "n_docs": len(docs),
484
+ "n_tree_nodes": len(tree.all_nodes) if hasattr(tree, "all_nodes") else -1,
485
+ "n_tree_layers": tree.num_layers if hasattr(tree, "num_layers") else -1,
486
+ "retrieved_session_ids": retrieved_session_ids,
487
+ "retrieval_metric": retrieval_metric,
488
+ }
489
+ print(json.dumps(result), file=out_f, flush=True)
490
+
491
+ print(
492
+ f"[{di}/{len(benchmark)}] qid={qid} | "
493
+ f"docs={len(docs)} | nodes={result['n_tree_nodes']} | "
494
+ f"layers={result['n_tree_layers']} | "
495
+ f"retrieved_sessions={len(retrieved_session_ids)}"
496
+ )
497
+ print(f" Q: {question[:100]}...")
498
+ print(f" A: {answer[:200]}...")
499
+
500
+ except Exception as e:
501
+ print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
502
+ import traceback
503
+ traceback.print_exc()
504
+ continue
505
+
506
+ out_f.close()
507
+ print(f"\nDone. Results saved to {args.out_file}")
508
+
509
+
510
+ if __name__ == "__main__":
511
+ main()
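Since the retrieval metrics in this script operate purely on session-ID strings, they can be sanity-checked in isolation. The session IDs and summaries below are made up for illustration, with extract_session_ids_from_context and evaluate_retrieval imported from (or pasted alongside) this module:

context = (
    "[Session u1_s3 | Date: 2023-05-01]\nUser booked a trip to Kyoto.\n\n"
    "[Session u1_s7 | Date: 2023-06-12]\nUser asked about ramen shops.\n\n"
)
retrieved = extract_session_ids_from_context(context)      # ['u1_s3', 'u1_s7']
r_any, r_all = evaluate_retrieval(retrieved, ["u1_s7", "u1_s9"])
print(r_any, r_all)                                         # 1.0 0.0 -- one of the two gold sessions was retrieved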
baselines/read-agent/read_agent_demo.ipynb ADDED
@@ -0,0 +1,976 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "toc_visible": true
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ }
16
+ },
17
+ "cells": [
18
+ {
19
+ "cell_type": "markdown",
20
+ "source": [
21
+ "![read_agent_teaser](https://read-agent.github.io/img/teaser.png)"
22
+ ],
23
+ "metadata": {
24
+ "id": "1iqyV7VcsiXT"
25
+ }
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {
31
+ "id": "MYOnCMh83ZRE"
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "!wget https://github.com/nyu-mll/quality/raw/main/data/v1.0.1/QuALITY.v1.0.1.htmlstripped.dev\n",
36
+ "import re, time, datetime, json, string, copy"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "source": [
42
+ "# @title Using OpenAI GPT model (DO NOT run the next cell if using GPT)\n",
43
+ "!pip3 install openai\n",
44
+ "import openai\n",
45
+ "\n",
46
+ "key = 'YOUR API KEY' #@param {type: \"string\"}\n",
47
+ "gpt_client = openai.OpenAI(api_key=key)\n",
48
+ "model_type = 'gpt'\n",
49
+ "\n",
50
+ "def query_gpt_model(\n",
51
+ " prompt: str,\n",
52
+ " lm: str = 'gpt-3.5-turbo-1106',\n",
53
+ " temperature: float = 0.0,\n",
54
+ " max_decode_steps: int = 512,\n",
55
+ " seconds_to_reset_tokens: float = 30.0,\n",
56
+ ") -> str:\n",
57
+ " while True:\n",
58
+ " try:\n",
59
+ " raw_response = gpt_client.chat.completions.with_raw_response.create(\n",
60
+ " model=lm,\n",
61
+ " max_tokens=max_decode_steps,\n",
62
+ " temperature=temperature,\n",
63
+ " messages=[\n",
64
+ " {'role': 'user', 'content': prompt},\n",
65
+ " ]\n",
66
+ " )\n",
67
+ " completion = raw_response.parse()\n",
68
+ " return completion.choices[0].message.content\n",
69
+ " except openai.RateLimitError as e:\n",
70
+ " print(f'{datetime.datetime.now()}: query_gpt_model: RateLimitError {e.message}: {e}')\n",
71
+ " time.sleep(seconds_to_reset_tokens)\n",
72
+ " except openai.APIError as e:\n",
73
+ " print(f'{datetime.datetime.now()}: query_gpt_model: APIError {e.message}: {e}')\n",
74
+ " print(f'{datetime.datetime.now()}: query_gpt_model: Retrying after 5 seconds...')\n",
75
+ " time.sleep(5)"
76
+ ],
77
+ "metadata": {
78
+ "id": "oz0kOxYJ4n3e",
79
+ "cellView": "form"
80
+ },
81
+ "execution_count": null,
82
+ "outputs": []
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "source": [
87
+ "# @title Using Google Gemini model (DO NOT run this if using GPT)\n",
88
+ "!pip3 install -q -U google-generativeai\n",
89
+ "import google.generativeai as genai\n",
90
+ "\n",
91
+ "key = 'YOUR API KEY' #@param {type: \"string\"}\n",
92
+ "\n",
93
+ "genai.configure(api_key=key)\n",
94
+ "model = genai.GenerativeModel('gemini-pro')\n",
95
+ "model_type = 'gemini'\n",
96
+ "\n",
97
+ "def query_gemini_model(\n",
98
+ " prompt: str,\n",
99
+ " retries: int = 10,\n",
100
+ ") -> str:\n",
101
+ " while True and retries > 0:\n",
102
+ " try:\n",
103
+ " response = model.generate_content(prompt)\n",
104
+ " text_response = response.text.replace(\"**\", \"\")\n",
105
+ " return text_response\n",
106
+ " except Exception as e:\n",
107
+ " print(f'{datetime.datetime.now()}: query_gemini_model: Error: {e}')\n",
108
+ " print(f'{datetime.datetime.now()}: query_gemini_model: Retrying after 5 seconds...')\n",
109
+ " retries -= 1\n",
110
+ " time.sleep(5)"
111
+ ],
112
+ "metadata": {
113
+ "cellView": "form",
114
+ "id": "YcP_tIpZKNFY"
115
+ },
116
+ "execution_count": null,
117
+ "outputs": []
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "source": [
122
+ "def query_model(prompt):\n",
123
+ " if model_type == \"gpt\":\n",
124
+ " return query_gpt_model(prompt)\n",
125
+ " elif model_type == \"gemini\":\n",
126
+ " return query_gemini_model(prompt)"
127
+ ],
128
+ "metadata": {
129
+ "id": "pYm2GsBGEvAI"
130
+ },
131
+ "execution_count": null,
132
+ "outputs": []
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "source": [
137
+ "#@title Load a QuALITY example\n",
138
+ "\n",
139
+ "# Fields that are straight text copies from raw example to processed example.\n",
140
+ "_ONE2ONE_FIELDS = (\n",
141
+ " 'article',\n",
142
+ " 'article_id',\n",
143
+ " 'set_unique_id',\n",
144
+ " 'writer_id',\n",
145
+ " 'source',\n",
146
+ " 'title',\n",
147
+ " 'topic',\n",
148
+ " 'url',\n",
149
+ " 'writer_id',\n",
150
+ " 'author',\n",
151
+ ")\n",
152
+ "\n",
153
+ "quality_dev = []\n",
154
+ "\n",
155
+ "with open('QuALITY.v1.0.1.htmlstripped.dev', 'r') as f:\n",
156
+ " for line in f.readlines():\n",
157
+ " j = json.loads(line)\n",
158
+ " fields = {k: j[k] for k in _ONE2ONE_FIELDS}\n",
159
+ " fields.update({\n",
160
+ " 'questions': [q['question'] for q in j['questions']],\n",
161
+ " 'question_ids': [q['question_unique_id'] for q in j['questions']],\n",
162
+ " 'difficults': [q['difficult'] for q in j['questions']],\n",
163
+ " 'options': [q['options'] for q in j['questions']],\n",
164
+ " })\n",
165
+ "\n",
166
+ " fields.update({\n",
167
+ " 'gold_labels': [q['gold_label'] for q in j['questions']],\n",
168
+ " 'writer_labels': [q['writer_label'] for q in j['questions']],\n",
169
+ " })\n",
170
+ "\n",
171
+ " quality_dev.append(fields)\n",
172
+ "\n",
173
+ "example = quality_dev[13]"
174
+ ],
175
+ "metadata": {
176
+ "id": "1B70Rqg97aXu",
177
+ "cellView": "form"
178
+ },
179
+ "execution_count": null,
180
+ "outputs": []
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "source": [
185
+ "#@title Helper functions\n",
186
+ "\n",
187
+ "all_lowercase_letters = string.ascii_lowercase # \"abcd...xyz\"\n",
188
+ "bracketed_lowercase_letters_set = set(\n",
189
+ " [f\"({l})\" for l in all_lowercase_letters]\n",
190
+ ") # {\"(a)\", ...}\n",
191
+ "bracketed_uppercase_letters_set = set(\n",
192
+ " [f\"({l.upper()})\" for l in all_lowercase_letters]\n",
193
+ ") # {\"(a)\", ...}\n",
194
+ "\n",
195
+ "choices = ['(A)', '(B)', '(C)', '(D)']\n",
196
+ "\n",
197
+ "def get_index_from_symbol(answer):\n",
198
+ " \"\"\"Get the index from the letter symbols A, B, C, D, to extract answer texts.\n",
199
+ "\n",
200
+ " Args:\n",
201
+ " answer (str): the string of answer like \"(B)\".\n",
202
+ "\n",
203
+ " Returns:\n",
204
+ " index (int): how far the given choice is from \"a\", like 1 for answer \"(B)\".\n",
205
+ " \"\"\"\n",
206
+ " answer = str(answer).lower()\n",
207
+ " # extract the choice letter from within bracket\n",
208
+ " if answer in bracketed_lowercase_letters_set:\n",
209
+ " answer = re.findall(r\"\\(.*?\\)\", answer)[0][1]\n",
210
+ " index = ord(answer) - ord(\"a\")\n",
211
+ " return index\n",
212
+ "\n",
213
+ "def count_words(text):\n",
214
+ " \"\"\"Simple word counting.\"\"\"\n",
215
+ " return len(text.split())\n",
216
+ "\n",
217
+ "def quality_gutenberg_parser(raw_article):\n",
218
+ " \"\"\"Parse Gutenberg articles in the QuALITY dataset.\"\"\"\n",
219
+ " lines = []\n",
220
+ " previous_line = None\n",
221
+ " for i, line in enumerate(raw_article.split('\\n')):\n",
222
+ " line = line.strip()\n",
223
+ " original_line = line\n",
224
+ " if line == '':\n",
225
+ " if previous_line == '':\n",
226
+ " line = '\\n'\n",
227
+ " else:\n",
228
+ " previous_line = original_line\n",
229
+ " continue\n",
230
+ " previous_line = original_line\n",
231
+ " lines.append(line)\n",
232
+ " return ' '.join(lines)"
233
+ ],
234
+ "metadata": {
235
+ "id": "nQsb3n6pOlz2",
236
+ "cellView": "form"
237
+ },
238
+ "execution_count": null,
239
+ "outputs": []
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "source": [
244
+ "#@title ReadAgent (1) Episode Pagination\n",
245
+ "\n",
246
+ "prompt_pagination_template = \"\"\"\n",
247
+ "You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.\n",
248
+ "Numbered label are in angeled brackets. For example, if the label number is 19, it shows as <19> in text.\n",
249
+ "Please choose one label that it is natural to break reading.\n",
250
+ "Such point can be scene transition, end of a dialogue, end of an argument, narrative transition, etc.\n",
251
+ "Please answer the break point label and explain.\n",
252
+ "For example, if <57> is a good point to break, answer with \\\"Break point: <57>\\n Because ...\\\"\n",
253
+ "\n",
254
+ "Passage:\n",
255
+ "\n",
256
+ "{0}\n",
257
+ "{1}\n",
258
+ "{2}\n",
259
+ "\n",
260
+ "\"\"\"\n",
261
+ "\n",
262
+ "def parse_pause_point(text):\n",
263
+ " text = text.strip(\"Break point: \")\n",
264
+ " if text[0] != '<':\n",
265
+ " return None\n",
266
+ " for i, c in enumerate(text):\n",
267
+ " if c == '>':\n",
268
+ " if text[1:i].isnumeric():\n",
269
+ " return int(text[1:i])\n",
270
+ " else:\n",
271
+ " return None\n",
272
+ " return None\n",
273
+ "\n",
274
+ "\n",
275
+ "def quality_pagination(example,\n",
276
+ " word_limit=600,\n",
277
+ " start_threshold=280,\n",
278
+ " max_retires=10,\n",
279
+ " verbose=True,\n",
280
+ " allow_fallback_to_last=True):\n",
281
+ " article = example['article']\n",
282
+ " title = example['title']\n",
283
+ " print(f\"[Pagination][Article {title}]\")\n",
284
+ " paragraphs = quality_gutenberg_parser(article).split('\\n')\n",
285
+ "\n",
286
+ " i = 0\n",
287
+ " pages = []\n",
288
+ " while i < len(paragraphs):\n",
289
+ " preceding = \"\" if i == 0 else \"...\\n\" + '\\n'.join(pages[-1])\n",
290
+ " passage = [paragraphs[i]]\n",
291
+ " wcount = count_words(paragraphs[i])\n",
292
+ " j = i + 1\n",
293
+ " while wcount < word_limit and j < len(paragraphs):\n",
294
+ " wcount += count_words(paragraphs[j])\n",
295
+ " if wcount >= start_threshold:\n",
296
+ " passage.append(f\"<{j}>\")\n",
297
+ " passage.append(paragraphs[j])\n",
298
+ " j += 1\n",
299
+ " passage.append(f\"<{j}>\")\n",
300
+ " end_tag = \"\" if j == len(paragraphs) else paragraphs[j] + \"\\n...\"\n",
301
+ "\n",
302
+ " pause_point = None\n",
303
+ " if wcount < 350:\n",
304
+ " pause_point = len(paragraphs)\n",
305
+ " else:\n",
306
+ " prompt = prompt_pagination_template.format(preceding, '\\n'.join(passage), end_tag)\n",
307
+ " response = query_model(prompt=prompt).strip()\n",
308
+ " pause_point = parse_pause_point(response)\n",
309
+ " if pause_point and (pause_point <= i or pause_point > j):\n",
310
+ " print(f\"prompt:\\n{prompt},\\nresponse:\\n{response}\\n\")\n",
311
+ " print(f\"i:{i} j:{j} pause_point:{pause_point}\")\n",
312
+ " pause_point = None\n",
313
+ " if pause_point is None:\n",
314
+ " if allow_fallback_to_last:\n",
315
+ " pause_point = j\n",
316
+ " else:\n",
317
+ " raise ValueError(f\"prompt:\\n{prompt},\\nresponse:\\n{response}\\n\")\n",
318
+ "\n",
319
+ " page = paragraphs[i:pause_point]\n",
320
+ " pages.append(page)\n",
321
+ " if verbose:\n",
322
+ " print(f\"Paragraph {i}-{pause_point-1}\", page)\n",
323
+ " i = pause_point\n",
324
+ " print(f\"[Pagination] Done with {len(pages)} pages\")\n",
325
+ " return pages\n",
326
+ "\n",
327
+ "pages = quality_pagination(example)"
328
+ ],
329
+ "metadata": {
330
+ "id": "BfFkEQKx0u9U"
331
+ },
332
+ "execution_count": null,
333
+ "outputs": []
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "source": [
338
+ "#@title ReadAgent (2) Memory Gisting\n",
339
+ "\n",
340
+ "prompt_shorten_template = \"\"\"\n",
341
+ "Please shorten the following passage.\n",
342
+ "Just give me a shortened version. DO NOT explain your reason.\n",
343
+ "\n",
344
+ "Passage:\n",
345
+ "{}\n",
346
+ "\n",
347
+ "\"\"\"\n",
348
+ "\n",
349
+ "def quality_gisting(example, pages, word_limit=600, start_threshold=280, verbose=True):\n",
350
+ " article = example['article']\n",
351
+ " title = example['title']\n",
352
+ " word_count = count_words(article)\n",
353
+ " print(f\"[Gisting][Article {title}], {word_count} words\")\n",
354
+ "\n",
355
+ " shortened_pages = []\n",
356
+ " for i, page in enumerate(pages):\n",
357
+ " prompt = prompt_shorten_template.format('\\n'.join(page))\n",
358
+ " response = query_model(prompt)\n",
359
+ " shortened_text = response.strip()\n",
360
+ " shortened_pages.append(shortened_text)\n",
361
+ " if verbose:\n",
362
+ " print(\"[gist] page {}:\".format(i), shortened_text, flush=True)\n",
363
+ " shortened_article = '\\n'.join(shortened_pages)\n",
364
+ " gist_word_count = count_words(shortened_article)\n",
365
+ " if verbose:\n",
366
+ " print(\"Shortened article:\\n\", shortened_article, flush=True)\n",
367
+ " output = copy.deepcopy(example)\n",
368
+ " output.update({'title': title, 'word_count': word_count, 'gist_word_count': gist_word_count, 'shortened_pages': shortened_pages, 'pages': pages})\n",
369
+ " if verbose:\n",
370
+ " print(f\"compression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})\")\n",
371
+ " return output\n",
372
+ "example_with_gists = quality_gisting(example, pages)"
373
+ ],
374
+ "metadata": {
375
+ "id": "DLBolKnkS_9y"
376
+ },
377
+ "execution_count": null,
378
+ "outputs": []
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "source": [
383
+ "#@title ReadAgent (3) Look-Up\n",
384
+ "\n",
385
+ "prompt_lookup_template = \"\"\"\n",
386
+ "The following text is what you remembered from reading an article and a multiple choice question related to it.\n",
387
+ "You may read 1 to 6 page(s) of the article again to refresh your memory to prepare yourselve for the question.\n",
388
+ "Please respond with which page(s) you would like to read.\n",
389
+ "For example, if your only need to read Page 8, respond with \\\"I want to look up Page [8] to ...\\\";\n",
390
+ "if your would like to read Page 7 and 12, respond with \\\"I want to look up Page [7, 12] to ...\\\";\n",
391
+ "if your would like to read Page 2, 3, 7, 15 and 18, respond with \\\"I want to look up Page [2, 3, 7, 15, 18] to ...\\\".\n",
392
+ "if your would like to read Page 3, 4, 5, 12, 13 and 16, respond with \\\"I want to look up Page [3, 3, 4, 12, 13, 16] to ...\\\".\n",
393
+ "DO NOT select more pages if you don't need to.\n",
394
+ "DO NOT answer the question yet.\n",
395
+ "\n",
396
+ "Text:\n",
397
+ "{}\n",
398
+ "\n",
399
+ "Question:\n",
400
+ "{}\n",
401
+ "{}\n",
402
+ "\n",
403
+ "Take a deep breath and tell me: Which page(s) would you like to read again?\n",
404
+ "\"\"\"\n",
405
+ "\n",
406
+ "prompt_answer_template = \"\"\"\n",
407
+ "Read the following article and answer a multiple choice question.\n",
408
+ "For example, if (C) is correct, answer with \\\"Answer: (C) ...\\\"\n",
409
+ "\n",
410
+ "Article:\n",
411
+ "{}\n",
412
+ "\n",
413
+ "Question:\n",
414
+ "{}\n",
415
+ "{}\n",
416
+ "\n",
417
+ "\"\"\"\n",
418
+ "\n",
419
+ "def quality_parallel_lookup(example, verbose=True):\n",
420
+ " preprocessed_pages = example['pages']\n",
421
+ " article = example['article']\n",
422
+ " title = example['title']\n",
423
+ " word_count = example['word_count']\n",
424
+ " gist_word_count = example['gist_word_count']\n",
425
+ " pages = example['pages']\n",
426
+ " shortened_pages = example['shortened_pages']\n",
427
+ " questions = example['questions']\n",
428
+ " options = example['options']\n",
429
+ " gold_labels = example['gold_labels'] # numerical [1, 2, 3, 4]\n",
430
+ "\n",
431
+ " print(f\"[Look-Up][Article {title}] {word_count} words\")\n",
432
+ "\n",
433
+ " model_choices = []\n",
434
+ " lookup_page_ids = []\n",
435
+ "\n",
436
+ " shortened_pages_pidx = []\n",
437
+ " for i, shortened_text in enumerate(shortened_pages):\n",
438
+ " shortened_pages_pidx.append(\"<Page {}>\\n\".format(i) + shortened_text)\n",
439
+ " shortened_article = '\\n'.join(shortened_pages_pidx)\n",
440
+ "\n",
441
+ " expanded_gist_word_counts = []\n",
442
+ " for i, label in enumerate(gold_labels):\n",
443
+ " # only test the first question for demo\n",
444
+ " if i != 1:\n",
445
+ " continue\n",
446
+ " q = questions[i]\n",
447
+ " print(\"question: \", q)\n",
448
+ " options_i = [f\"{ol} {o}\" for ol, o in zip(choices, options[i])]\n",
449
+ " print(\"options: \", \"\\n\".join(options_i))\n",
450
+ " prompt_lookup = prompt_lookup_template.format(shortened_article, q, '\\n'.join(options_i))\n",
451
+ "\n",
452
+ " page_ids = []\n",
453
+ "\n",
454
+ " response = query_model(prompt=prompt_lookup).strip()\n",
455
+ "\n",
456
+ " try: start = response.index('[')\n",
457
+ " except ValueError: start = len(response)\n",
458
+ " try: end = response.index(']')\n",
459
+ " except ValueError: end = 0\n",
460
+ " if start < end:\n",
461
+ " page_ids_str = response[start+1:end].split(',')\n",
462
+ " page_ids = []\n",
463
+ " for p in page_ids_str:\n",
464
+ " if p.strip().isnumeric():\n",
465
+ " page_id = int(p)\n",
466
+ " if page_id < 0 or page_id >= len(pages):\n",
467
+ " print(\"Skip invalid page number: \", page_id, flush=True)\n",
468
+ " else:\n",
469
+ " page_ids.append(page_id)\n",
470
+ "\n",
471
+ " if verbose:\n",
472
+ " print(\"Model chose to look up page {}\".format(page_ids))\n",
473
+ "\n",
474
+ " # Memory expansion after look-up, replacing the target shortened page with the original page\n",
475
+ " expanded_shortened_pages = shortened_pages[:]\n",
476
+ " if len(page_ids) > 0:\n",
477
+ " for page_id in page_ids:\n",
478
+ " expanded_shortened_pages[page_id] = '\\n'.join(pages[page_id])\n",
479
+ "\n",
480
+ " expanded_shortened_article = '\\n'.join(expanded_shortened_pages)\n",
481
+ " expanded_gist_word_count = count_words(expanded_shortened_article)\n",
482
+ " if verbose:\n",
483
+ " print(\"Expanded shortened article:\\n\", expanded_shortened_article, flush=True)\n",
484
+ " prompt_answer = prompt_answer_template.format(expanded_shortened_article, q, '\\n'.join(options_i))\n",
485
+ "\n",
486
+ " # If the response doesn't follow the template, retry\n",
487
+ " model_choice = None\n",
488
+ " response = query_model(prompt=prompt_answer)\n",
489
+ " response = response.strip()\n",
490
+ " for j, choice in enumerate(choices):\n",
491
+ " if response.startswith(f\"Answer: {choice}\") or response.startswith(f\"Answer: {choice[1]}\"):\n",
492
+ " model_choice = j+1\n",
493
+ " break\n",
494
+ " is_correct = 1 if model_choice == label else 0\n",
495
+ " print(f\"question: {q}\")\n",
496
+ " print(f\"reference answer: {choices[label]}, model prediction: {choices[model_choice]}, is_correct: {is_correct}\")\n",
497
+ " print(f\"compression rate {round(100.0 - gist_word_count/word_count*100, 2)}% ({gist_word_count}/{word_count})\")\n",
498
+ " print(f\"compression rate after look-up {round(100.0 - expanded_gist_word_count/word_count*100, 2)}% ({expanded_gist_word_count}/{word_count})\")\n",
499
+ "\n",
500
+ "quality_parallel_lookup(example_with_gists)"
501
+ ],
502
+ "metadata": {
503
+ "id": "8YKNTyDsXNIn"
504
+ },
505
+ "execution_count": null,
506
+ "outputs": []
507
+ },
508
+ {
509
+ "cell_type": "markdown",
510
+ "source": [
511
+ "#Prompts that we used in the paper\n",
512
+ "\n",
513
+ "In the following we show the prompts that were used for the QuALTIY, QMSum, NarrativeQA datasets with the PaLM 2-L model. While there are slight differences in prompt design, most of these are not due to optimizing prompts for specific datasets but rather a results of that each author wrote the prompts independently."
514
+ ],
515
+ "metadata": {
516
+ "id": "Gn8fjomx7iRz"
517
+ }
518
+ },
519
+ {
520
+ "cell_type": "code",
521
+ "source": [
522
+ "# @title The prompts we used for QuALITY with PaLM 2-L\n",
523
+ "\n",
524
+ "\n",
525
+ "# Pagination\n",
526
+ "pagination_prompt_template = \"\"\"\n",
527
+ "You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.\n",
528
+ "Numbered label are in angeled brackets. For example, if the label number is 19, it shows as <19> in text.\n",
529
+ "Please choose one label that it is natural to break reading.\n",
530
+ "Such point can be scene transition, end of a dialogue, end of an argument, narrative transition, etc.\n",
531
+ "Please answer the break point label and explain.\n",
532
+ "For example, if <57> is a good point to break, answer with \\\"Break point: <57>\\n Because ...\\\"\n",
533
+ "\n",
534
+ "Passage:\n",
535
+ "\n",
536
+ "{passage_text}\n",
537
+ "{end_tag}\n",
538
+ "\n",
539
+ "\"\"\"\n",
540
+ "# passage_text: a chunk of text.\n",
541
+ "# end_tag: a string, whose value is \"\" if the text is at the end of the article, and otherwise \"\\n...\".\n",
542
+ "\n",
543
+ "\n",
544
+ "\n",
545
+ "# Gisting\n",
546
+ "gisting_prompt_template = \"\"\"\n",
547
+ "Please shorten the following passage.\n",
548
+ "Just give me a shortened version. DO NOT explain your reason.\n",
549
+ "\n",
550
+ "Passage:\n",
551
+ "{page_text}\n",
552
+ "\n",
553
+ "\"\"\"\n",
554
+ "# page_text: a page of text\n",
555
+ "\n",
556
+ "\n",
557
+ "\n",
558
+ "# Parallel Look-up (ReadAgent-P, up to 5 pages)\n",
559
+ "parallel_lookup_prompt_template = \"\"\"\n",
560
+ "The following text is what you remembered from reading an article and a multiple choice question related to it.\n",
561
+ "You may read 1 to 5 page(s) of the article again to refresh your memory to prepare yourselve for the question.\n",
562
+ "Please respond with which page(s) you would like to read again.\n",
563
+ "For example, if your would like to only read Page 8, respond with \\\"I want to look up Page [8] to ...\\\";\n",
564
+ "if your would like to read Page 7 and 12, respond with \\\"I want to look up Page [7, 12] to ...\\\";\n",
565
+ "if your would like to read Page 2, 3, 7, 15 and 18, respond with \\\"I want to look up Page [2, 3, 7, 15, 18] to ...\\\".\n",
566
+ "DO NOT select more pages if you don't need to.\n",
567
+ "DO NOT answer the question yet.\n",
568
+ "\n",
569
+ "Text:\n",
570
+ "{concatenated_gists}\n",
571
+ "\n",
572
+ "Question:\n",
573
+ "{question}\n",
574
+ "{options}\n",
575
+ "\n",
576
+ "Take a deep breath and tell me: Which page(s) would you like to read again?\n",
577
+ "\"\"\"\n",
578
+ "# concatenated_gists: concatenated gists\n",
579
+ "# question: a question\n",
580
+ "# options: multiple-choice options\n",
581
+ "\n",
582
+ "\n",
583
+ "\n",
584
+ "# Sequential Look-up (ReadAgent-S, up to 5 pages)\n",
585
+ "sequential_lookup_prompt_template = \"\"\"\n",
586
+ "The following text is what you remember from reading an article, followed by a question about the article.\n",
587
+ "You may read multiple pages of the article again to refresh your memory and prepare to answer the question.\n",
588
+ "Each page that you re-read can significantly improve your chance of answering the question correctly.\n",
589
+ "Please specify a SINGLE page you would like to read again or say \"STOP\".\n",
590
+ "To read a page again, respond with \"Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
591
+ "You can only specify a SINGLE page in your response at this time.\n",
592
+ "DO NOT select more pages if you don't need to.\n",
593
+ "To stop, simply say \"STOP\".\n",
594
+ "DO NOT answer the question in your response.\n",
595
+ "\n",
596
+ "Text:\n",
597
+ "{concatenated_gists}\n",
598
+ "End of text.\n",
599
+ "\n",
600
+ "Pages re-read already (DO NOT ask to read them again):\n",
601
+ "{past_page_numbers}\n",
602
+ "\n",
603
+ "Question:\n",
604
+ "{question}\n",
605
+ "{options}\n",
606
+ "\n",
607
+ "Specify a SINGLE page to read again, or say STOP:\n",
608
+ "\"\"\"\n",
609
+ "# concatenated_gists: concatenated gists\n",
610
+ "# past_page_numbers: page numbers that have already been retrieved\n",
611
+ "# question: a question\n",
612
+ "# options: options\n",
613
+ "\n",
614
+ "\n",
615
+ "\n",
616
+ "# Response/Answer\n",
617
+ "answer_prompt_template = \"\"\"\n",
618
+ "Read the following article and answer a multiple choice question.\n",
619
+ "For example, if (C) is correct, answer with \\\"Answer: (C) ...\\\"\n",
620
+ "\n",
621
+ "Article:\n",
622
+ "{concatenated_pages_and_gists}\n",
623
+ "\n",
624
+ "Question:\n",
625
+ "{question}\n",
626
+ "{options}\n",
627
+ "\n",
628
+ "\"\"\"\n",
629
+ "# concatenated_pages_and_gists: concatenated raw pages and gists\n",
630
+ "# question: a question\n",
631
+ "# options: options"
632
+ ],
633
+ "metadata": {
634
+ "id": "PGvYRIpO3J3Y"
635
+ },
636
+ "execution_count": null,
637
+ "outputs": []
638
+ },
639
+ {
640
+ "cell_type": "code",
641
+ "source": [
642
+ "# @title The prompts we used for QMSum with PaLM 2-L\n",
643
+ "\n",
644
+ "\n",
645
+ "# Pagination\n",
646
+ "pagination_prompt_template = \"\"\"\n",
647
+ "You are given a passage that is taken from a larger meeting transcript.\n",
648
+ "There are some numbered labels between the paragraphs (like <0>) in the passage.\n",
649
+ "Please choose one label at a natural transition in the passage.\n",
650
+ "For example, the label can be at the end of a dialogue, the end of an argument, a change in the topic being discussed, etc.\n",
651
+ "Please respond with the label and explain your choice.\n",
652
+ "For example, if <57> is a natural transition, answer with \"Label: <57>\\n Because ...\"\n",
653
+ "\n",
654
+ "Passage:\n",
655
+ "\n",
656
+ "{preceding_text}\n",
657
+ "{passage_text}\n",
658
+ "{end_tag}\n",
659
+ "\n",
660
+ "\"\"\"\n",
661
+ "# preceding_text: a fraction of previous context\n",
662
+ "# passage_text: a chunk of text.\n",
663
+ "# end_tag: a string, whose value is \"\" if the text is at the end of the article, and otherwise \"\\n...\".\n",
664
+ "\n",
665
+ "\n",
666
+ "\n",
667
+ "# Gisting\n",
668
+ "gisting_prompt_template = \"\"\"\n",
669
+ "Please shorten the following passage.\n",
670
+ "Just give a shortened version. DO NOT explain your reasoning.\n",
671
+ "\n",
672
+ "Passage:\n",
673
+ "{page_text}\n",
674
+ "\n",
675
+ "\"\"\"\n",
676
+ "# page_text: a page of text\n",
677
+ "\n",
678
+ "\n",
679
+ "\n",
680
+ "# Parallel Look-up (ReadAgent-P, up to 2 pages)\n",
681
+ "parallel_lookup_prompt_template = \"\"\"\n",
682
+ "The following text is what you remember from reading a meeting transcript, followed by a question about the transcript.\n",
683
+ "You may read 1 or 2 pages of the transcript again to refresh your memory to prepare to answer the question.\n",
684
+ "Please respond with which page(s) you would like to read.\n",
685
+ "For example, if your would only like to read Page 8, respond with \"I want to look up Page [8] ...\"\n",
686
+ "If you would like to read Page 7 and 12, respond with \"I want to look up Page [7, 12] ...\".\n",
687
+ "Only select as many pages as you need, but no more than 2 pages.\n",
688
+ "Don't answer the question yet.\n",
689
+ "\n",
690
+ "Text:\n",
691
+ "{text}\n",
692
+ "End of text.\n",
693
+ "\n",
694
+ "Question:\n",
695
+ "{question}\n",
696
+ "\n",
697
+ "Which page(s) would you like to look up?\n",
698
+ "\"\"\"\n",
699
+ "# concatenated_gists: Concatenated gists\n",
700
+ "# question: a question\n",
701
+ "\n",
702
+ "\n",
703
+ "\n",
704
+ "# Sequential Look-up (ReadAgent-S)\n",
705
+ "sequential_lookup_prompt_template = \"\"\"\n",
706
+ "The following text is what you remember from reading a meeting transcript, followed by a question about the transcript.\n",
707
+ "You may read multiple pages of the transcript again to refresh your memory and prepare to answer the question.\n",
708
+ "Each page that you re-read can significantly improve your chance of answering the question correctly.\n",
709
+ "Please specify a SINGLE page you would like to read again or say \"STOP\".\n",
710
+ "To read a page again, respond with \"Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
711
+ "You can only specify a SINGLE page in your response at this time.\n",
712
+ "DO NOT select more pages if you don't need to.\n",
713
+ "To stop, simply say \"STOP\".\n",
714
+ "DO NOT answer the question in your response.\n",
715
+ "\n",
716
+ "Text:\n",
717
+ "{concatenated_gists}\n",
718
+ "End of text.\n",
719
+ "\n",
720
+ "Pages re-read already (DO NOT ask to read them again):\n",
721
+ "{past_page_numbers}\n",
722
+ "\n",
723
+ "Question:\n",
724
+ "{question}\n",
725
+ "\n",
726
+ "Specify a SINGLE page to read again, or say STOP:\n",
727
+ "\"\"\"\n",
728
+ "# concatenated_gists: concatenated gists\n",
729
+ "# past_page_numbers: page numbers that have already been retrieved\n",
730
+ "# question: a question\n",
731
+ "\n",
732
+ "\n",
733
+ "\n",
734
+ "# Response/Answer\n",
735
+ "answer_prompt_template = \"\"\"\n",
736
+ "Read the question and text below and then answer the question.\n",
737
+ "\n",
738
+ "Question:\n",
739
+ "{question}\n",
740
+ "\n",
741
+ "Text:\n",
742
+ "{concatenated_pages_and_gists}\n",
743
+ "End of Text.\n",
744
+ "\n",
745
+ "Answer the question based on the above passage and retrieved pages. Your answer should be short and concise.\n",
746
+ "\"\"\"\n",
747
+ "# question: a question\n",
748
+ "# concatenated_pages_and_gists: concatenated raw pages and gists"
749
+ ],
750
+ "metadata": {
751
+ "id": "Ya48p13EhhhI"
752
+ },
753
+ "execution_count": null,
754
+ "outputs": []
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "source": [
759
+ "# @title The prompts we used for NarrativeQA - Gutenburg with PaLM 2-L\n",
760
+ "\n",
761
+ "\n",
762
+ "# Pagination\n",
763
+ "pagination_prompt_template = \"\"\"\n",
764
+ "You are given a passage that is taken from a larger text (article, book, ...) and some numbered labels between the paragraphs in the passage.\n",
765
+ "Numbered label are in angeled brackets. For example, if the label number is 19, it shows as <19> in text.\n",
766
+ "Please choose one label that marks a major section break point.\n",
767
+ "Such points can be the beginning/end of a book, beginning/end of a chapter, end of a content table, a scene transition, end of a dialogue, etc.\n",
768
+ "If a point is chosen for the beginning of a book/chapter/etc and there is a title of the new book/chapter/etc, the break point must be chosen at a position right before the section number and title, not after.\n",
769
+ "\n",
770
+ "Please answer the break point label and explain.\n",
771
+ "For example, if <57> is a good point to break, answer with \\\"Breakpoint: <57> ...\\\"\n",
772
+ "\n",
773
+ "Text:\n",
774
+ "\n",
775
+ "{preceding_text}\n",
776
+ "{passage_text}\n",
777
+ "{end_tag}\n",
778
+ "\n",
779
+ "\"\"\"\n",
780
+ "# preceding_text: a fraction of previous context\n",
781
+ "# passage_text: a chunk of text.\n",
782
+ "# end_tag: a string, whose value is \"\" if the text is at the end of the article, and otherwise \"\\n...\".\n",
783
+ "\n",
784
+ "\n",
785
+ "\n",
786
+ "# Gisting\n",
787
+ "gisting_prompt_template = \"\"\"\n",
788
+ "Please shorten the following passage.\n",
789
+ "Just give me a shortened version. DO NOT explain your reason.\n",
790
+ "\n",
791
+ "Passage:\n",
792
+ "{page_text}\n",
793
+ "\n",
794
+ "\"\"\"\n",
795
+ "# page_text: a page of text\n",
796
+ "\n",
797
+ "\n",
798
+ "\n",
799
+ "# Parallel Look-up (ReadAgent-P, up to 2 pages)\n",
800
+ "parallel_lookup_prompt_template = \"\"\"\n",
801
+ "The following text is what you remembered from reading an article and a question related to it.\n",
802
+ "You may read 1 or 2 page(s) of the article again to refresh your memory to prepare yourselve for the question.\n",
803
+ "Please respond with which page(s) you would like to read in the order of importance, beginning with the most important page number.\n",
804
+ "For example, if your only need to read Page 8, respond with \\\"I want to look up Page [8] to ...\\\";\n",
805
+ "if your would like to read Page 12 and 7, respond with \\\"I want to look up Page [12, 7] to ...\\\";\n",
806
+ "DO NOT select more pages if you don't need to.\n",
807
+ "You don't need to answer the question yet.\n",
808
+ "\n",
809
+ "Text:\n",
810
+ "{concatenated_gists}\n",
811
+ "\n",
812
+ "Question:\n",
813
+ "{question}\n",
814
+ "\n",
815
+ "\"\"\"\n",
816
+ "# concatenated_gists: Concatenated gists\n",
817
+ "# question: a question\n",
818
+ "\n",
819
+ "\n",
820
+ "\n",
821
+ "# Sequential Look-up (ReadAgent-S)\n",
822
+ "sequential_lookup_prompt_template = \"\"\"\n",
823
+ "The following text is what you remember from reading a meeting transcript, followed by a question about the transcript.\n",
824
+ "You may read multiple pages of the transcript again to refresh your memory and prepare to answer the question.\n",
825
+ "Each page that you re-read can significantly improve your chance of answering the question correctly.\n",
826
+ "Please specify a SINGLE page you would like to read again or say \"STOP\".\n",
827
+ "To read a page again, respond with \"Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
828
+ "You can only specify a SINGLE page in your response at this time.\n",
829
+ "DO NOT select more pages if you don't need to.\n",
830
+ "To stop, simply say \"STOP\".\n",
831
+ "DO NOT answer the question in your response.\n",
832
+ "\n",
833
+ "Text:\n",
834
+ "{concatenated_gists}\n",
835
+ "End of text.\n",
836
+ "\n",
837
+ "Pages re-read already (DO NOT ask to read them again):\n",
838
+ "{past_page_numbers}\n",
839
+ "\n",
840
+ "Question:\n",
841
+ "{question}\n",
842
+ "\n",
843
+ "Specify a SINGLE page to read again, or say STOP:\n",
844
+ "\"\"\"\n",
845
+ "# concatenated_gists: concatenated gists\n",
846
+ "# past_page_numbers: page numbers that have already been retrieved\n",
847
+ "# question: a question\n",
848
+ "\n",
849
+ "\n",
850
+ "\n",
851
+ "# Response/Answer\n",
852
+ "answer_prompt_template = \"\"\"\n",
853
+ "{concatenated_pages_and_gists}\n",
854
+ "\n",
855
+ "Question:\n",
856
+ "{question}\n",
857
+ "\n",
858
+ "Answer the question based on the above passage and retrieved pages. Your answer should be short and concise.\n",
859
+ "\"\"\"\n",
860
+ "# concatenated_pages_and_gists: concatenated raw pages and gists\n",
861
+ "# question: a question"
862
+ ],
863
+ "metadata": {
864
+ "id": "U210HZJuP4_u"
865
+ },
866
+ "execution_count": null,
867
+ "outputs": []
868
+ },
869
+ {
870
+ "cell_type": "code",
871
+ "source": [
872
+ "# @title The prompts we used for NarrativeQA - Movie Scripts with PaLM 2-L\n",
873
+ "\n",
874
+ "\n",
875
+ "# Pagination\n",
876
+ "pagination_prompt_template = \"\"\"\n",
877
+ "You are given a movie script and some numbered labels between the lines in the script.\n",
878
+ "Numbered label are in angeled brackets.\n",
879
+ "Please choose one label that it is natural to break reading. The label should be between <{start}> and <{end}>.\n",
880
+ "Such point can be scene transition, end of a dialogue, end of an argument, narrative transition, etc.\n",
881
+ "The answer should end with \"The break point is: <number>\", where the break point number is between angeled brackets.\n",
882
+ "\n",
883
+ "Script:\n",
884
+ "\n",
885
+ "{passage_text}\n",
886
+ "\"\"\"\n",
887
+ "# passage_text: a chunk of text.\n",
888
+ "\n",
889
+ "\n",
890
+ "\n",
891
+ "# Gisting\n",
892
+ "gisting_prompt_template = \"\"\"\n",
893
+ "Please shorten the following passage. The shortened passage should be in 128 tokens. Please refer to people with their full names whenever possible.\n",
894
+ "Just give me a shortened version. DO NOT explain your reason. If there is no meaning information in the passage, output \"I don't have enough information to shorten the passage.\"\n",
895
+ "\n",
896
+ "Passage:\n",
897
+ "{page_text}\n",
898
+ "\n",
899
+ "\"\"\"\n",
900
+ "# page_text: a page of text\n",
901
+ "\n",
902
+ "\n",
903
+ "\n",
904
+ "# Parallel Look-up (ReadAgent-P, up to 2 pages)\n",
905
+ "parallel_lookup_prompt_template = \"\"\"\n",
906
+ "The following text includes a summary of each page in a movie script, followed by a question about the script.\n",
907
+ "\n",
908
+ "Summary:\n",
909
+ "{concatenated_gists}\n",
910
+ "\n",
911
+ "Question:\n",
912
+ "{question}\n",
913
+ "\n",
914
+ "Based on the summary of each page, you may read the full details of 1 to 2 page(s) to obtain more information to answer the question.\n",
915
+ "Please respond with which page(s) you would like to read.\n",
916
+ "For example, if you only need to read Page X_0, the answer should end with \\\"I want to look up Page [X_0]\\\";\n",
917
+ "if you would like to read Page X_0 and X_1, the answer should end with \\\"I want to look up Page [X_0, X_1]\\\".\n",
918
+ "X_i above is a page index between 0 and {end}. DO NOT select more pages if you don't need to.\n",
919
+ "You don't need to answer the question yet.\n",
920
+ "\"\"\"\n",
921
+ "# concatenated_gists: Concatenated gists\n",
922
+ "# question: a question\n",
923
+ "\n",
924
+ "\n",
925
+ "\n",
926
+ "# Sequential Look-up (ReadAgent-S)\n",
927
+ "sequential_lookup_prompt_template = \"\"\"\n",
928
+ "The following text includes a summary of each page in a movie script, followed by a question about the script, and the previous answer based on the summary and already re-read pages.\n",
929
+ "\n",
930
+ "Summary:\n",
931
+ "{concatenated_gists}\n",
932
+ "\n",
933
+ "Pages re-read already (DO NOT ask to read them again):\n",
934
+ "{past_page_numbers}\n",
935
+ "\n",
936
+ "Question:\n",
937
+ "{question}\n",
938
+ "\n",
939
+ "Previous Answer:\n",
940
+ "{previous_answer}\n",
941
+ "\n",
942
+ "Based on the summary of each page, you may read the full details of multiple pages to obtain more information to answer the question.\n",
943
+ "To read a page again, respond with \"I want to look up Page $PAGE_NUM\", replacing $PAGE_NUM with the target page number.\n",
944
+ "PAGE_NUM is a page index between 0 and {end}, excluding {pages_reread}.\n",
945
+ "You can only specify a SINGLE page in your response at this time. The page should not be in the re-read pages.\n",
946
+ "To stop, simply say \"STOP\".\n",
947
+ "DO NOT answer the question in your response.\n",
948
+ "You don't need to answer the question yet.\n",
949
+ "\"\"\"\n",
950
+ "# concatenated_gists: concatenated gists\n",
951
+ "# past_page_numbers: (only after the first query) page numbers that have already been retrieved\n",
952
+ "# question: a question\n",
953
+ "# previous_answer: the previous answer given by the model with gists and previously retrieved raw pages\n",
954
+ "\n",
955
+ "\n",
956
+ "\n",
957
+ "# Response/Answer\n",
958
+ "answer_prompt_template = \"\"\"\n",
959
+ "{concatenated_pages_and_gists}\n",
960
+ "\n",
961
+ "Question:\n",
962
+ "{question}\n",
963
+ "\n",
964
+ "Answer the question based on the above passage and retrieved pages. Your answer should be short and concise.\n",
965
+ "\"\"\"\n",
966
+ "# concatenated_pages_and_gists: concatenated raw pages and gists\n",
967
+ "# question: a question"
968
+ ],
969
+ "metadata": {
970
+ "id": "K1pw1NPasHPC"
971
+ },
972
+ "execution_count": null,
973
+ "outputs": []
974
+ }
975
+ ]
976
+ }
baselines/read-agent/run_readagent_baseline.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ ReadAgent baseline for the EvolV-Mem benchmark.
3
+
4
+ Adapts ReadAgent's 3-step pipeline for EvolV-Mem:
5
+ 1. Pagination: each session = one page (natural segmentation, no LLM needed)
6
+ 2. Gisting: use pre-computed session summaries (no LLM needed)
7
+ 3. Look-up + Answer: LLM reads gists, selects pages to expand, then answers
8
+
9
+ Since ~1000 sessions don't fit in one prompt as gists, we pre-filter with SBert
10
+ to top-N sessions, then apply ReadAgent's look-up on those.
11
+
12
+ Usage:
13
+ python baselines/read-agent/run_readagent_baseline.py \
14
+ --in_file dataset/evolv_mem_v4.json \
15
+ --out_file output/readagent_qwen30b_v4.jsonl \
16
+ --summary_file dataset/all_session_summary.json \
17
+ --sessions_file dataset/all_sessions.json \
18
+ --profile_file metadata/generated_user_profile.json
19
+
20
+ Env vars:
21
+ VLLM_BASE_URL (default http://localhost:8000/v1)
22
+ VLLM_API_KEY (default EMPTY)
23
+ """
24
+
25
+ import argparse
26
+ import json
27
+ import logging
28
+ import os
29
+ import re
30
+ import sys
31
+ import time
32
+ from collections import defaultdict
33
+ from typing import Dict, List, Optional
34
+
35
+ import numpy as np
36
+ from tqdm import tqdm
37
+
38
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # vLLM LLM helper
42
+ # ---------------------------------------------------------------------------
43
+
44
+ MODEL_NAME = os.getenv("VLLM_MODEL_NAME", "Qwen/Qwen3-30B-A3B-Instruct-2507")
45
+
46
+
47
+ def get_llm_client():
48
+ from openai import OpenAI
49
+ return OpenAI(
50
+ base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
51
+ api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
52
+ )
53
+
54
+
55
+ def llm_call(client, prompt: str, max_tokens: int = 4096, temperature: float = 0.0) -> str:
56
+ for attempt in range(6):
57
+ try:
58
+ response = client.chat.completions.create(
59
+ model=MODEL_NAME,
60
+ messages=[{"role": "user", "content": prompt}],
61
+ max_tokens=max_tokens,
62
+ temperature=temperature,
63
+ )
64
+ content = response.choices[0].message.content if response.choices else None
65
+ if content is None:
66
+ # The endpoint occasionally returns null content; back off and retry the same prompt
67
+ wait = min(2 ** attempt * 2, 30)
68
+ print(f"[WARN] LLM returned None content (attempt {attempt+1}); retrying in {wait}s")
69
+ time.sleep(wait)
70
+ continue
71
+ return content.strip()
72
+ except Exception as e:
73
+ msg = str(e).lower()
74
+ if any(code in msg for code in ("429", "500", "503", "rate limit")):
75
+ wait = min(2 ** attempt * 5, 60)
76
+ print(f"[WARN] LLM retry {attempt+1}/6, sleeping {wait}s: {e}")
77
+ time.sleep(wait)
78
+ continue
79
+ print(f"[ERROR] LLM call failed: {e}")
80
+ raise
81
+ raise RuntimeError("LLM call failed after 6 retries")
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # ReadAgent prompts (adapted from the notebook for chat-history QA)
86
+ # ---------------------------------------------------------------------------
87
+
88
+ PROMPT_LOOKUP = """The following text contains gists (short summaries) of pages from a user's chat history. Each page is one conversation session.
89
+ You are also given a question about the user's chat history.
90
+
91
+ You may select up to {max_pages} page(s) to read in full to help answer the question.
92
+ Please respond with which page(s) you would like to read.
93
+ For example, if you only need to read Page 3, respond with "I want to look up Page [3] to ...";
94
+ if you would like to read Page 2 and 7, respond with "I want to look up Page [2, 7] to ...".
95
+ DO NOT select more pages than necessary.
96
+ DO NOT answer the question yet.
97
+
98
+ Text:
99
+ {concatenated_gists}
100
+
101
+ Question:
102
+ {question}
103
+
104
+ Current Date: {question_date}
105
+
106
+ Take a deep breath and tell me: Which page(s) would you like to read again?
107
+ """
108
+
109
+ PROMPT_ANSWER = """Read the following chat history and answer the question.
110
+
111
+ {profile_section}
112
+
113
+ Chat History:
114
+ {concatenated_pages_and_gists}
115
+
116
+ Current Date: {question_date}
117
+ Question: {question}
118
+ Answer:"""
119
+
120
+
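+ # Illustrative sketch (not invoked anywhere in this script) of the look-up
+ # round trip: fill PROMPT_LOOKUP with a couple of made-up gists, then parse a
+ # well-formed model reply with parse_lookup_pages (defined further below).
+ def _example_lookup_roundtrip():
+     gists = "\n\n".join([
+         "<Page 0> [Date: 2023-05-01]\nUser planned a trip to Kyoto.",
+         "<Page 1> [Date: 2023-06-12]\nUser mentioned spraining an ankle while hiking.",
+     ])
+     prompt = PROMPT_LOOKUP.format(
+         max_pages=2,
+         concatenated_gists=gists,
+         question="What injury did the user report?",
+         question_date="2023-07-01",
+     )
+     reply = "I want to look up Page [1] to check the injury."
+     return prompt, parse_lookup_pages(reply, max_page=2)  # -> (..., [1])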
121
+ # ---------------------------------------------------------------------------
122
+ # Embedding pre-filter
123
+ # ---------------------------------------------------------------------------
124
+
125
+ def embed_and_filter(
126
+ question: str,
127
+ session_ids: List[str],
128
+ session_gists: List[str],
129
+ embedding_model,
130
+ top_k: int = 100,
131
+ ) -> List[int]:
132
+ """Return indices of top-k most relevant sessions by cosine similarity."""
133
+ if len(session_ids) <= top_k:
134
+ return list(range(len(session_ids)))
135
+
136
+ question_emb = embedding_model.encode(question)
137
+ gist_embs = embedding_model.encode(session_gists)
138
+
139
+ question_norm = question_emb / (np.linalg.norm(question_emb) + 1e-10)
140
+ gist_norms = gist_embs / (np.linalg.norm(gist_embs, axis=1, keepdims=True) + 1e-10)
141
+ similarities = gist_norms @ question_norm
142
+
143
+ top_indices = np.argsort(similarities)[::-1][:top_k].tolist()
144
+ # Sort by original order (chronological) for coherent reading
145
+ top_indices.sort()
146
+ return top_indices
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # ReadAgent Look-up: parse page selection from LLM response
151
+ # ---------------------------------------------------------------------------
152
+
153
+ def parse_lookup_pages(response: str, max_page: int) -> List[int]:
154
+ """Parse page indices from ReadAgent look-up response like 'Page [2, 7, 12]'."""
155
+ try:
156
+ start = response.index('[')
157
+ end = response.index(']')
158
+ except ValueError:
159
+ return []
160
+
161
+ page_ids = []
162
+ for p in response[start + 1:end].split(','):
163
+ p = p.strip()
164
+ if p.isnumeric():
165
+ pid = int(p)
166
+ if 0 <= pid < max_page:
167
+ page_ids.append(pid)
168
+ return page_ids
169
+
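+ # Example (illustrative): indices outside [0, max_page) are silently dropped.
+ # >>> parse_lookup_pages("I want to look up Page [2, 7] to verify.", max_page=5)
+ # [2]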
170
+
171
+ # ---------------------------------------------------------------------------
172
+ # Retrieval metrics
173
+ # ---------------------------------------------------------------------------
174
+
175
+ def evaluate_retrieval(recalled_docs, correct_docs):
176
+ recall_any = float(any(doc in recalled_docs for doc in correct_docs))
177
+ recall_all = float(all(doc in recalled_docs for doc in correct_docs))
178
+ return recall_any, recall_all
179
+
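+ # Example (illustrative): with gold sessions ["s1", "s2"] and retrieved
+ # sessions ["s1", "s7"], recall_any is 1.0 (s1 was found) but recall_all is 0.0.
+ # >>> evaluate_retrieval(["s1", "s7"], ["s1", "s2"])
+ # (1.0, 0.0)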
180
+
181
+ def print_average_metrics(retrieval_metric_list):
182
+ metric_sums = defaultdict(float)
183
+ metric_counts = defaultdict(int)
184
+ for metric in retrieval_metric_list:
185
+ for k, v in metric.items():
186
+ metric_sums[k] += v
187
+ metric_counts[k] += 1
188
+ print(" Average retrieval metrics:")
189
+ for k in sorted(metric_sums):
190
+ avg = metric_sums[k] / metric_counts[k]
191
+ print(f" {k}: {avg:.4f}")
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # Main
196
+ # ---------------------------------------------------------------------------
197
+
198
+ def main():
199
+ parser = argparse.ArgumentParser(description="ReadAgent baseline for EvolV-Mem")
200
+ parser.add_argument("--in_file", type=str, required=True,
201
+ help="Path to evolv_mem_v4.json")
202
+ parser.add_argument("--out_file", type=str, required=True,
203
+ help="Output JSONL file")
204
+ parser.add_argument("--summary_file", type=str, required=True,
205
+ help="Path to all_session_summary.json")
206
+ parser.add_argument("--sessions_file", type=str, required=True,
207
+ help="Path to all_sessions.json")
208
+ parser.add_argument("--profile_file", type=str, default=None,
209
+ help="Path to generated_user_profile.json")
210
+ parser.add_argument("--embedding_model", type=str,
211
+ default="sentence-transformers/multi-qa-mpnet-base-cos-v1",
212
+ help="SentenceTransformer model for pre-filtering")
213
+ # ReadAgent params
214
+ parser.add_argument("--gist_top_k", type=int, default=100,
215
+ help="Number of sessions to keep after embedding pre-filter (default 100)")
216
+ parser.add_argument("--max_lookup_pages", type=int, default=10,
217
+ help="Max pages the LLM can select for full reading (default 10)")
218
+ # Limit
219
+ parser.add_argument("--limit", type=int, default=None,
220
+ help="Process only the first N questions")
221
+ args = parser.parse_args()
222
+
223
+ # -----------------------------------------------------------------------
224
+ # Load data
225
+ # -----------------------------------------------------------------------
226
+ print(f"Loading benchmark from {args.in_file} ...")
227
+ with open(args.in_file) as f:
228
+ benchmark = json.load(f)
229
+ if args.limit:
230
+ benchmark = benchmark[:args.limit]
231
+ print(f" {len(benchmark)} questions loaded.")
232
+
233
+ print(f"Loading session summaries from {args.summary_file} ...")
234
+ with open(args.summary_file) as f:
235
+ summaries = json.load(f)
236
+ print(f" {len(summaries)} session summaries loaded.")
237
+
238
+ print(f"Loading sessions from {args.sessions_file} ...")
239
+ with open(args.sessions_file) as f:
240
+ all_sessions = json.load(f)
241
+ print(f" {len(all_sessions)} sessions loaded.")
242
+
243
+ profiles = {}
244
+ if args.profile_file and os.path.exists(args.profile_file):
245
+ print(f"Loading user profiles from {args.profile_file} ...")
246
+ with open(args.profile_file) as f:
247
+ profiles = json.load(f)
248
+ print(f" {len(profiles)} profiles loaded.")
249
+
250
+ # -----------------------------------------------------------------------
251
+ # Resume support
252
+ # -----------------------------------------------------------------------
253
+ existing_qids = set()
254
+ if os.path.exists(args.out_file):
255
+ with open(args.out_file) as f:
256
+ for line in f:
257
+ line = line.strip()
258
+ if line:
259
+ existing_qids.add(json.loads(line)["question_id"])
260
+ print(f" Resuming: {len(existing_qids)} questions already processed.")
261
+
262
+ # -----------------------------------------------------------------------
263
+ # Initialize models
264
+ # -----------------------------------------------------------------------
265
+ print("Initializing embedding model ...")
266
+ from sentence_transformers import SentenceTransformer
267
+ embedding_model = SentenceTransformer(args.embedding_model)
268
+
269
+ print("Initializing vLLM client ...")
270
+ client = get_llm_client()
271
+
272
+ # -----------------------------------------------------------------------
273
+ # Process questions
274
+ # -----------------------------------------------------------------------
275
+ retrieval_metric_list = []
276
+ out_f = open(args.out_file, "a")
277
+
278
+ for di, entry in enumerate(tqdm(benchmark, desc="ReadAgent baseline")):
279
+ qid = entry["question_id"]
280
+ question = entry["question"]
281
+ question_date = entry["question_date"]
282
+
283
+ if qid in existing_qids:
284
+ continue
285
+
286
+ try:
287
+ haystack_sids = entry["haystack_session_ids"]
288
+ haystack_dates = entry["haystack_dates"]
289
+
290
+ # === Step 1 & 2: Pagination + Gisting (free with cached data) ===
291
+ # Each session = one page; session summary = gist
292
+ page_sids = []
293
+ page_dates = []
294
+ page_gists = []
295
+ for sid, date_str in zip(haystack_sids, haystack_dates):
296
+ summary_data = summaries.get(sid)
297
+ if summary_data is None:
298
+ continue
299
+ text = summary_data.get("session_summary", "")
300
+ if not text:
301
+ turn_sums = summary_data.get("turn_summaries", [])
302
+ text = " ".join(turn_sums) if turn_sums else ""
303
+ if not text:
304
+ continue
305
+ page_sids.append(sid)
306
+ page_dates.append(date_str)
307
+ page_gists.append(text)
308
+
309
+ if not page_gists:
310
+ result = {
311
+ "q_idx": di, "question_id": qid,
312
+ "hypothesis": "Insufficient information to answer.",
313
+ "n_pages": 0,
314
+ }
315
+ print(json.dumps(result), file=out_f, flush=True)
316
+ continue
317
+
318
+ # === Embedding pre-filter to top-K gists ===
319
+ filtered_indices = embed_and_filter(
320
+ question, page_sids, page_gists, embedding_model, top_k=args.gist_top_k
321
+ )
322
+ filtered_sids = [page_sids[i] for i in filtered_indices]
323
+ filtered_dates = [page_dates[i] for i in filtered_indices]
324
+ filtered_gists = [page_gists[i] for i in filtered_indices]
325
+
326
+ # Build gist text with page numbers
327
+ gist_lines = []
328
+ for local_idx, (sid, date, gist) in enumerate(
329
+ zip(filtered_sids, filtered_dates, filtered_gists)
330
+ ):
331
+ gist_lines.append(f"<Page {local_idx}> [Date: {date}]\n{gist}")
332
+ concatenated_gists = "\n\n".join(gist_lines)
333
+
334
+ # === Step 3a: ReadAgent Look-Up ===
335
+ lookup_prompt = PROMPT_LOOKUP.format(
336
+ max_pages=args.max_lookup_pages,
337
+ concatenated_gists=concatenated_gists,
338
+ question=question,
339
+ question_date=question_date,
340
+ )
341
+ lookup_response = llm_call(client, lookup_prompt, max_tokens=4096, temperature=0.0)
342
+ selected_page_ids = parse_lookup_pages(lookup_response, len(filtered_sids))
343
+
344
+ print(f" [{di}] Look-up selected pages: {selected_page_ids} "
345
+ f"(out of {len(filtered_sids)} gists)")
346
+
347
+ # === Step 3b: Expand selected pages with full session content ===
348
+ expanded_lines = []
349
+ retrieved_session_ids = []
350
+ for local_idx, (sid, date, gist) in enumerate(
351
+ zip(filtered_sids, filtered_dates, filtered_gists)
352
+ ):
353
+ if local_idx in selected_page_ids:
354
+ # Expand: use full session content
355
+ session_turns = all_sessions.get(sid, [])
356
+ full_text = "\n".join(
357
+ f"{t.get('role', 'user')}: {t.get('content', '')}"
358
+ for t in session_turns
359
+ )
360
+ expanded_lines.append(
361
+ f"<Page {local_idx}> [Session {sid} | Date: {date}] [FULL]\n{full_text}"
362
+ )
363
+ retrieved_session_ids.append(sid)
364
+ else:
365
+ # Keep gist
366
+ expanded_lines.append(
367
+ f"<Page {local_idx}> [Date: {date}] [GIST]\n{gist}"
368
+ )
369
+ concatenated_pages_and_gists = "\n\n".join(expanded_lines)
370
+
371
+ # === Step 3c: Answer ===
372
+ user_id = qid.split("_q_")[0] if "_q_" in qid else qid
373
+ user_profile = profiles.get(user_id, None)
374
+ profile_section = f"User Profile:\n{user_profile}" if user_profile else ""
375
+
376
+ answer_prompt = PROMPT_ANSWER.format(
377
+ profile_section=profile_section,
378
+ concatenated_pages_and_gists=concatenated_pages_and_gists,
379
+ question=question,
380
+ question_date=question_date,
381
+ )
382
+ answer = llm_call(client, answer_prompt, max_tokens=8192, temperature=0.0)
383
+
384
+ # === Retrieval metrics ===
385
+ answer_session_ids = entry.get("answer_session_ids", [])
386
+ retrieval_metric = {}
387
+ if answer_session_ids and retrieved_session_ids:
388
+ for topk in [5, 10, 20, 30]:
389
+ r_any, r_all = evaluate_retrieval(
390
+ retrieved_session_ids[:topk], answer_session_ids
391
+ )
392
+ retrieval_metric[f"recall_any@{topk}"] = r_any
393
+ retrieval_metric[f"recall_all@{topk}"] = r_all
394
+ retrieval_metric_list.append(retrieval_metric)
395
+ print_average_metrics(retrieval_metric_list)
396
+
397
+ # === Output ===
398
+ result = {
399
+ "q_idx": di,
400
+ "question_id": qid,
401
+ "hypothesis": answer,
402
+ "n_pages_total": len(page_gists),
403
+ "n_pages_filtered": len(filtered_sids),
404
+ "n_pages_expanded": len(selected_page_ids),
405
+ "retrieved_session_ids": retrieved_session_ids,
406
+ "retrieval_metric": retrieval_metric,
407
+ }
408
+ print(json.dumps(result), file=out_f, flush=True)
409
+
410
+ print(f" [{di}] Q: {question[:100]}...")
411
+ print(f" [{di}] A: {answer[:200]}...")
412
+
413
+ except Exception as e:
414
+ print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
415
+ import traceback
416
+ traceback.print_exc()
417
+ continue
418
+
419
+ out_f.close()
420
+ print(f"\nDone. Results saved to {args.out_file}")
421
+
422
+
423
+ if __name__ == "__main__":
424
+ main()
evaluate_qa.py ADDED
@@ -0,0 +1,916 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+
7
+ import numpy as np
8
+ from openai import OpenAI
9
+ from tqdm import tqdm
10
+
11
+ try:
12
+ from openai import AzureOpenAI
13
+ from azure.identity import (
14
+ AzureCliCredential,
15
+ ChainedTokenCredential,
16
+ ManagedIdentityCredential,
17
+ get_bearer_token_provider,
18
+ )
19
+
20
+ AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
21
+ if AZURE_OAUTH_SCOPE:
22
+ credential = get_bearer_token_provider(
23
+ ChainedTokenCredential(
24
+ AzureCliCredential(),
25
+ ManagedIdentityCredential(),
26
+ ),
27
+ AZURE_OAUTH_SCOPE,
28
+ )
29
+ else:
30
+ credential = None
31
+ except ImportError:
32
+ AzureOpenAI = None
33
+ credential = None
34
+
35
+ from model_zoo import model_zoo
36
+
37
+
38
+ # Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL).
39
+ endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
40
+ # OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy).
41
+ TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "")
42
+
43
+ ATOMIC_PROMPT_VERSION = "atomic-v1"
44
+ LEGACY_PROMPT_VERSION = "binary-v0"
45
+ ATOM_SCORES = {
46
+ "correct": 1.0,
47
+ "partially_correct": 0.5,
48
+ "missing": 0.0,
49
+ "incorrect": 0.0,
50
+ }
51
+
52
+
53
+ def _retryable_status(e) -> "int | None":
54
+ status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
55
+ if status in (429, 500, 503, 403):
56
+ return status
57
+ resp = getattr(e, "response", None)
58
+ if resp is not None and getattr(resp, "status_code", None) in (429, 500, 503):
59
+ return resp.status_code
60
+ msg = str(e).lower()
61
+ if "429" in msg or "rate limit" in msg:
62
+ return 429
63
+ if "500" in msg or "internal server error" in msg:
64
+ return 500
65
+ if "503" in msg or "api configuration unavailable" in msg:
66
+ return 503
67
+ return None
68
+
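+ # Example (illustrative): an exception whose message only mentions a rate limit
+ # is still treated as retryable, so the caller backs off instead of raising.
+ # >>> _retryable_status(Exception("Rate limit exceeded, please retry later"))
+ # 429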
69
+
70
+ def parse_json_object(text):
71
+ text = (text or "").strip()
72
+ if text.startswith("```"):
73
+ text = re.sub(r"^```(?:json)?", "", text).strip()
74
+ text = re.sub(r"```$", "", text).strip()
75
+ try:
76
+ return json.loads(text)
77
+ except json.JSONDecodeError:
78
+ start = text.find("{")
79
+ end = text.rfind("}") + 1
80
+ if start >= 0 and end > start:
81
+ return json.loads(text[start:end])
82
+ raise
83
+
84
+
85
+ def sanitize_model_name(name):
86
+ return re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
87
+
88
+
89
+ def read_json_or_jsonl(path):
90
+ try:
91
+ with open(path, "r", encoding="utf-8") as f:
92
+ return json.load(f)
93
+ except json.JSONDecodeError:
94
+ with open(path, "r", encoding="utf-8") as f:
95
+ return [json.loads(line) for line in f if line.strip()]
96
+
97
+
98
+ def read_existing_jsonl(path):
99
+ if not path or not os.path.exists(path):
100
+ return {}
101
+ rows = {}
102
+ with open(path, "r", encoding="utf-8") as f:
103
+ for line in f:
104
+ if not line.strip():
105
+ continue
106
+ obj = json.loads(line)
107
+ if "question_id" in obj:
108
+ rows[obj["question_id"]] = obj
109
+ return rows
110
+
111
+
112
+ def write_json(path, obj):
113
+ tmp_path = path + ".tmp"
114
+ with open(tmp_path, "w", encoding="utf-8") as f:
115
+ json.dump(obj, f, ensure_ascii=False, indent=2)
116
+ f.write("\n")
117
+ os.replace(tmp_path, path)
118
+
119
+
120
+ def append_jsonl(path, row):
121
+ with open(path, "a", encoding="utf-8") as f:
122
+ print(json.dumps(row, ensure_ascii=False), file=f, flush=True)
123
+
124
+
125
+ def default_result_file(hyp_file, metric_model, eval_mode):
126
+ if eval_mode == "legacy":
127
+ return f"{hyp_file}.eval-results-{metric_model}"
128
+ model_tag = sanitize_model_name(metric_model)
129
+ return f"{hyp_file}.eval-results-{model_tag}-{ATOMIC_PROMPT_VERSION}-{eval_mode}"
130
+
131
+
132
+ def default_rubric_file(ref_file):
133
+ return f"{ref_file}.{ATOMIC_PROMPT_VERSION}.rubric.json"
134
+
135
+
136
+ def load_rubric_file(path, ref_file):
137
+ if not path or not os.path.exists(path):
138
+ return {
139
+ "prompt_version": ATOMIC_PROMPT_VERSION,
140
+ "source_ref_file": ref_file,
141
+ "rubrics": {},
142
+ }
143
+ with open(path, "r", encoding="utf-8") as f:
144
+ data = json.load(f)
145
+ if "rubrics" in data:
146
+ data.setdefault("prompt_version", ATOMIC_PROMPT_VERSION)
147
+ data.setdefault("source_ref_file", ref_file)
148
+ return data
149
+ return {
150
+ "prompt_version": ATOMIC_PROMPT_VERSION,
151
+ "source_ref_file": ref_file,
152
+ "rubrics": data,
153
+ }
154
+
155
+
156
+ def question_type_guidance(task):
157
+ if task in ["Information Absence"]:
158
+ return (
159
+ "The correct answer is that the information is unavailable, absent, "
160
+ "not yet known, or not supported. A response that gives a concrete "
161
+ "answer instead of abstaining is incorrect."
162
+ )
163
+ if task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
164
+ return (
165
+ "Check every requested item, count, list member, and named fact. "
166
+ "Exact counts and required list coverage matter. Extra material "
167
+ "facts that change the answer should be flagged."
168
+ )
169
+ if task in ["Aggregation + Temporal"]:
170
+ return (
171
+ "Check both the aggregated facts and their time/order associations. "
172
+ "An answer can name the right item but still be wrong if the timing, "
173
+ "ordering, before/after relation, or year is wrong."
174
+ )
175
+ if task in ["Temporal Reasoning", "temporal-reasoning"]:
176
+ return (
177
+ "Check the specific time, date, year, sequence, duration, or temporal "
178
+ "relationship asked for. Accept +/-1 only for day/week/month duration "
179
+ "counts, not for years, event identity, or ordering."
180
+ )
181
+ if task in ["Knowledge Update", "knowledge-update"]:
182
+ return (
183
+ "Check the current or most recent state. Historical context is fine "
184
+ "only if the final/current state is clearly correct. Outdated states "
185
+ "presented as current are incorrect."
186
+ )
187
+ if task == "single-session-preference":
188
+ return (
189
+ "Check whether the response recalls and applies the stated preference. "
190
+ "Do not require unnecessary verbosity, but contradictions are incorrect."
191
+ )
192
+ return "Check whether the model response satisfies all required facts in the reference answer."
193
+
194
+
195
+ def get_anscheck_prompt(task, question, answer, response):
196
+ """Legacy binary yes/no prompt kept for backward compatibility."""
197
+ if task in ["Information Absence"]:
198
+ template = """You are evaluating whether a model correctly identified that a question is unanswerable given the available personal chat history.
199
+
200
+ Question: {question}
201
+
202
+ Explanation of why it is unanswerable: {answer}
203
+
204
+ Model Response: {response}
205
+
206
+ Evaluation criteria:
207
+ - CORRECT if the model explicitly states that the information is not available, insufficient, or that the question cannot be answered from the provided context.
208
+ - INCORRECT if the model fabricates an answer or fails to acknowledge the unanswerable nature of the question.
209
+ - The model does not need to use the exact word "unanswerable" -- expressing uncertainty or lack of information is sufficient.
210
+
211
+ Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
212
+ elif task in ["Aggregation", "single-session-user", "single-session-assistant", "multi-session"]:
213
+ template = """You are evaluating whether a model correctly answered a question that requires aggregating specific facts from a user's personal chat history.
214
+
215
+ Question: {question}
216
+
217
+ Reference Answer: {answer}
218
+
219
+ Model Response: {response}
220
+
221
+ Evaluation criteria:
222
+ - CORRECT if the response identifies all key items, facts, or counts present in the reference answer, even if phrased differently or with added context.
223
+ - INCORRECT if the response:
224
+ - States a wrong count (e.g., says "5" when the answer is "3")
225
+ - Omits one or more key items/facts listed in the reference
226
+ - Lists mostly wrong items even if the count is right
227
+ - Partial answers that cover only a subset of required items are INCORRECT.
228
+ - Verbose responses are acceptable as long as all reference items are present within them.
229
+ - If the response contains correct items but also lists additional plausible-sounding but unverified items beyond the reference, this does NOT make it incorrect -- evaluate only whether the reference items are covered.
230
+ - Semantic equivalence counts as correct (e.g., "RSI" = "Repetitive Strain Injury").
231
+
232
+ Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
233
+ elif task in ["Aggregation + Temporal"]:
234
+ template = """You are evaluating whether a model correctly answered a question that requires both aggregating facts AND reasoning about their temporal order or time associations from a user's personal chat history.
235
+
236
+ Question: {question}
237
+
238
+ Reference Answer: {answer}
239
+
240
+ Model Response: {response}
241
+
242
+ Evaluation criteria:
243
+ - CORRECT if the response captures both:
244
+ (a) all key events/facts listed in the reference, and
245
+ (b) their correct temporal associations (ordering, time periods, or "before/after" relationships).
246
+ - INCORRECT if the response:
247
+ - Omits one or more key events or facts from the reference
248
+ - Gets the temporal ordering or time associations wrong
249
+ - Captures only the content without the temporal aspects, or vice versa
250
+ - Responses that describe the correct progression/sequence in different words are acceptable.
251
+ - Partial answers covering only some events or ignoring time aspects are INCORRECT.
252
+ - Minor wording differences or additional explanatory context are acceptable.
253
+
254
+ Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
255
+ elif task in ["Temporal Reasoning", "temporal-reasoning"]:
256
+ template = """You are evaluating whether a model correctly answered a question about temporal relationships in a user's personal chat history.
257
+
258
+ Question: {question}
259
+
260
+ Reference Answer: {answer}
261
+
262
+ Model Response: {response}
263
+
264
+ Evaluation criteria:
265
+ - CORRECT if the response correctly identifies the specific time, date, year, sequence, or temporal relationship asked about.
266
+ - INCORRECT if the response states a wrong year, wrong sequence, wrong temporal relationship, or misidentifies which event came first/last.
267
+ - Off-by-one tolerance: if the question asks for a count of days, weeks, or months, accept answers that differ by +/-1. This tolerance does NOT apply to years or to identifying specific events/artifacts.
268
+ - Responses that correctly identify the fact but with verbose context are acceptable.
269
+ - If the response hedges but still states the correct answer, it is correct.
270
+
271
+ Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
272
+ elif task in ["Knowledge Update", "knowledge-update"]:
273
+ template = """You are evaluating whether a model correctly answered a question about the most recent or current state of something that changed over time in a user's personal chat history.
274
+
275
+ Question: {question}
276
+
277
+ Reference Answer: {answer}
278
+
279
+ Model Response: {response}
280
+
281
+ Evaluation criteria:
282
+ - CORRECT if the response correctly identifies the most recent/current state as described in the reference answer.
283
+ - The response may include earlier historical states as context -- this is acceptable as long as the current/final state is correctly identified and clearly stated.
284
+ - INCORRECT if the response:
285
+ - States an outdated/superseded state as the current one
286
+ - Omits the current state entirely
287
+ - Correctly describes history but draws the wrong conclusion about what the current state is
288
+ - Semantic equivalence counts (e.g., "flexitarian" and "semi-vegetarian diet with occasional meat" are equivalent if contextually clear).
289
+
290
+ Briefly explain your reasoning (2-3 sentences), then on the last line write only: yes or no"""
291
+ elif task == "single-session-preference":
292
+ template = """You are evaluating whether a model correctly answered a personalized question based on a user's stated preferences from their chat history.
293
+
294
+ Question: {question}
295
+
296
+ Reference Rubric: {answer}
297
+
298
+ Model Response: {response}
299
+
300
+ Evaluation criteria:
301
+ - CORRECT if the response recalls and applies the user's personal preferences correctly, even if not covering every point in the rubric.
302
+ - INCORRECT if the response ignores, contradicts, or misremembers the user's preferences.
303
+
304
+ Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
305
+ else:
306
+ template = """You are evaluating whether a model's response correctly answers a question based on a user's personal chat history.
307
+
308
+ Question: {question}
309
+
310
+ Reference Answer: {answer}
311
+
312
+ Model Response: {response}
313
+
314
+ Is the response correct? It is correct if it contains all key information from the reference answer, even if phrased differently.
315
+
316
+ Briefly explain your reasoning (1-2 sentences), then on the last line write only: yes or no"""
317
+ return template.format(question=question, answer=answer, response=response)
318
+
319
+
320
+ def build_rubric_prompt(task, question, answer):
321
+ return f"""You are creating an atomic grading rubric for an open-ended QA benchmark.
322
+
323
+ Question type: {task}
324
+ Question:
325
+ {question}
326
+
327
+ Reference answer:
328
+ {answer}
329
+
330
+ Question-type guidance:
331
+ {question_type_guidance(task)}
332
+
333
+ Decompose the reference answer into the smallest independently checkable requirements needed to answer the question.
334
+
335
+ Rules:
336
+ - Each atom should be a single required answer unit: an entity, count, date/year, order relation, current-state conclusion, or abstention requirement.
337
+ - If an entity and its temporal relation are inseparable for correctness, keep them in the same atom.
338
+ - For list/count questions, include an atom for the exact count when the question asks "how many", and atoms for each required listed item when the item identities matter.
339
+ - For Information Absence, usually use one atom requiring the response to clearly state that the information is unavailable/insufficient/not discussed, and add a strict note that concrete fabricated answers are wrong.
340
+ - Do not include supporting evidence requirements or session IDs unless the question explicitly asks for them.
341
+ - Weights should normally be 1.0. Use a higher weight only when one atom is clearly the main answer and other atoms are minor.
342
+
343
+ Return JSON only with this schema:
344
+ {{
345
+ "required_atoms": [
346
+ {{
347
+ "id": "a1",
348
+ "requirement": "short, specific grading requirement",
349
+ "weight": 1.0
350
+ }}
351
+ ],
352
+ "strict_notes": ["short note about exactness, ordering, abstention, or hallucination handling"]
353
+ }}"""
354
+
355
+
356
+ def build_atomic_eval_prompt(task, question, answer, response, rubric):
357
+ rubric_str = json.dumps(
358
+ {
359
+ "required_atoms": rubric["required_atoms"],
360
+ "strict_notes": rubric.get("strict_notes", []),
361
+ },
362
+ ensure_ascii=False,
363
+ indent=2,
364
+ )
365
+ return f"""You are an LLM-as-a-judge evaluating one model response against a reference answer.
366
+
367
+ Question type: {task}
368
+ Question:
369
+ {question}
370
+
371
+ Reference answer:
372
+ {answer}
373
+
374
+ Model response:
375
+ {response}
376
+
377
+ Atomic grading rubric:
378
+ {rubric_str}
379
+
380
+ Question-type guidance:
381
+ {question_type_guidance(task)}
382
+
383
+ Judge each atom independently.
384
+
385
+ Atom labels:
386
+ - correct: the response fully satisfies this atom, allowing semantic paraphrase.
387
+ - partially_correct: the response gets the main idea but is incomplete or slightly underspecified. Use this sparingly. Do not use it for wrong counts, wrong years, wrong named entities, wrong ordering, or a concrete answer to an Information Absence question.
388
+ - missing: the response does not address this atom.
389
+ - incorrect: the response contradicts this atom or gives the wrong count/entity/date/order/current state.
390
+
391
+ Also identify unsupported_or_contradictory material:
392
+ - severity "material": extra answer content that changes the final answer, adds extra items to an exact list/count, gives an outdated current state, fabricates a concrete answer for Information Absence, or contradicts any atom.
393
+ - severity "minor": harmless context or extra explanation that does not change the answer.
394
+
395
+ Return JSON only with this schema:
396
+ {{
397
+ "atom_judgments": [
398
+ {{
399
+ "id": "a1",
400
+ "label": "correct|partially_correct|missing|incorrect",
401
+ "rationale": "brief reason"
402
+ }}
403
+ ],
404
+ "unsupported_or_contradictory": [
405
+ {{
406
+ "text": "extra or contradictory claim",
407
+ "severity": "minor|material",
408
+ "rationale": "brief reason"
409
+ }}
410
+ ],
411
+ "absence_mismatch": false,
412
+ "overall_rationale": "one or two sentence summary"
413
+ }}"""
414
+
415
+
416
+ def llm_call(
417
+ deployment_name: str,
418
+ api_version: str,
419
+ _prompt: str,
420
+ debug: bool = False,
421
+ vllm: bool = False,
422
+ tritonai: bool = False,
423
+ nvidia: bool = False,
424
+ ):
425
+ if nvidia:
426
+ client = OpenAI(
427
+ api_key=os.getenv("NV_API_KEY"),
428
+ base_url="https://inference-api.nvidia.com",
429
+ )
430
+ while True:
431
+ try:
432
+ return client.chat.completions.create(
433
+ model=deployment_name,
434
+ messages=[{"role": "user", "content": _prompt}],
435
+ )
436
+ except Exception as e:
437
+ st = _retryable_status(e)
438
+ if st in (429, 500, 503, 403):
439
+ print(f"[WARN] HTTP {st} from NVIDIA API; sleeping 60s then retrying...", flush=True)
440
+ time.sleep(60)
441
+ continue
442
+ print("One exception captured", repr(e), flush=True)
443
+ raise
444
+
445
+ if tritonai:
446
+ client = OpenAI(
447
+ api_key=os.getenv("TRITONAI_API_KEY"),
448
+ base_url=TRITONAI_BASE_URL,
449
+ )
450
+ while True:
451
+ try:
452
+ return client.chat.completions.create(
453
+ model=deployment_name,
454
+ messages=[{"role": "user", "content": _prompt}],
455
+ )
456
+ except Exception as e:
457
+ st = _retryable_status(e)
458
+ if st in (429, 500, 503, 403):
459
+ print(f"[WARN] HTTP {st} from LiteLLM proxy; sleeping 60s then retrying...", flush=True)
460
+ time.sleep(60)
461
+ continue
462
+ print("One exception captured", repr(e), flush=True)
463
+ raise
464
+
465
+ if deployment_name.startswith("claude-"):
466
+ import anthropic
467
+
468
+ client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
469
+ while True:
470
+ try:
471
+ msg = client.messages.create(
472
+ model=deployment_name,
473
+ max_tokens=1024,
474
+ messages=[{"role": "user", "content": _prompt}],
475
+ )
476
+
477
+ class _Choice:
478
+ class _Msg:
479
+ def __init__(self, text):
480
+ self.content = text
481
+
482
+ def __init__(self, text):
483
+ self.message = self._Msg(text)
484
+
485
+ class _Completion:
486
+ def __init__(self, text):
487
+ self.choices = [_Choice(text)]
488
+
489
+ return _Completion(msg.content[0].text)
490
+ except Exception as e:
491
+ st = _retryable_status(e)
492
+ if st in (429, 500, 503, 403):
493
+ print(f"[WARN] HTTP {st} from Anthropic; sleeping 60s then retrying...", flush=True)
494
+ time.sleep(60)
495
+ continue
496
+ print("One exception captured", repr(e), flush=True)
497
+ raise
498
+
499
+ if vllm:
500
+ client = OpenAI(
501
+ base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
502
+ api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
503
+ )
504
+ elif debug:
505
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
506
+ else:
507
+ client = AzureOpenAI(
508
+ azure_endpoint=endpoint,
509
+ azure_ad_token_provider=credential,
510
+ api_version=api_version,
511
+ )
512
+
513
+ kwargs = {
514
+ "model": deployment_name,
515
+ "messages": [{"role": "system", "content": _prompt}],
516
+ }
517
+ while True:
518
+ try:
519
+ return client.chat.completions.create(**kwargs)
520
+ except Exception as e:
521
+ st = _retryable_status(e)
522
+ if st in (429, 500, 503, 403):
523
+ print(f"[WARN] HTTP {st} from LLM; sleeping 120s then retrying...", flush=True)
524
+ time.sleep(120)
525
+ continue
526
+ print("One exception captured", repr(e), flush=True)
527
+ raise
528
+
529
+
530
+ def call_json_llm(prompt, deployment_name, api_version, args, max_retries=3):
531
+ last_error = None
532
+ for attempt in range(max_retries):
533
+ completion = llm_call(
534
+ deployment_name,
535
+ api_version,
536
+ prompt,
537
+ debug=args.debug,
538
+ vllm=args.vllm,
539
+ tritonai=args.tritonai,
540
+ nvidia=args.nvidia,
541
+ )
542
+ content = completion.choices[0].message.content.strip()
543
+ try:
544
+ return parse_json_object(content), content
545
+ except Exception as e:
546
+ last_error = e
547
+ if attempt < max_retries - 1:
548
+ print(f"[WARN] Failed to parse judge JSON; retrying ({attempt + 1}/{max_retries})", flush=True)
549
+ time.sleep(2)
550
+ raise ValueError(f"Failed to parse JSON response from judge: {last_error}")
551
+
552
+
553
+ def fallback_rubric(qid, task, question, answer):
554
+ return {
555
+ "question_id": qid,
556
+ "question_type": task,
557
+ "question": question,
558
+ "reference_answer": answer,
559
+ "required_atoms": [
560
+ {
561
+ "id": "a1",
562
+ "requirement": f"Response must correctly answer the question according to the reference answer: {answer}",
563
+ "weight": 1.0,
564
+ }
565
+ ],
566
+ "strict_notes": ["Fallback single-atom rubric produced because the generated rubric was invalid."],
567
+ "prompt_version": ATOMIC_PROMPT_VERSION,
568
+ }
569
+
570
+
571
+ def normalize_rubric(qid, task, question, answer, parsed):
572
+ atoms = parsed.get("required_atoms", []) if isinstance(parsed, dict) else []
573
+ norm_atoms = []
574
+ for idx, atom in enumerate(atoms, start=1):
575
+ if not isinstance(atom, dict):
576
+ continue
577
+ requirement = str(atom.get("requirement", "")).strip()
578
+ if not requirement:
579
+ continue
580
+ atom_id = str(atom.get("id", f"a{idx}")).strip() or f"a{idx}"
581
+ try:
582
+ weight = float(atom.get("weight", 1.0))
583
+ except (TypeError, ValueError):
584
+ weight = 1.0
585
+ if weight <= 0:
586
+ weight = 1.0
587
+ norm_atoms.append({"id": atom_id, "requirement": requirement, "weight": weight})
588
+ if not norm_atoms:
589
+ return fallback_rubric(qid, task, question, answer)
590
+ strict_notes = parsed.get("strict_notes", [])
591
+ if not isinstance(strict_notes, list):
592
+ strict_notes = [str(strict_notes)]
593
+ return {
594
+ "question_id": qid,
595
+ "question_type": task,
596
+ "question": question,
597
+ "reference_answer": answer,
598
+ "required_atoms": norm_atoms,
599
+ "strict_notes": [str(x) for x in strict_notes],
600
+ "prompt_version": ATOMIC_PROMPT_VERSION,
601
+ }
602
+
603
+
604
+ def get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args):
605
+ qid = qdata["question_id"]
606
+ existing = rubric_data["rubrics"].get(qid)
607
+ if (
608
+ existing
609
+ and existing.get("prompt_version") == ATOMIC_PROMPT_VERSION
610
+ and existing.get("required_atoms")
611
+ and not args.force_rebuild_rubric
612
+ ):
613
+ return existing
614
+
615
+ task = qdata["question_type"]
616
+ prompt = build_rubric_prompt(task, qdata["question"], qdata["answer"])
617
+ try:
618
+ parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
619
+ rubric = normalize_rubric(qid, task, qdata["question"], qdata["answer"], parsed)
620
+ rubric["rubric_raw_response"] = raw
621
+ except Exception as e:
622
+ print(f"[WARN] Falling back to single-atom rubric for {qid}: {e}", flush=True)
623
+ rubric = fallback_rubric(qid, task, qdata["question"], qdata["answer"])
624
+ rubric_data["rubrics"][qid] = rubric
625
+ write_json(rubric_file, rubric_data)
626
+ return rubric
627
+
628
+
629
+ def compute_atomic_scores(rubric, parsed):
630
+ atoms = rubric.get("required_atoms", [])
631
+ judgments_by_id = {}
632
+ raw_judgments = parsed.get("atom_judgments", []) if isinstance(parsed, dict) else []
633
+ if isinstance(raw_judgments, list):
634
+ for judgment in raw_judgments:
635
+ if not isinstance(judgment, dict):
636
+ continue
637
+ atom_id = str(judgment.get("id", "")).strip()
638
+ label = str(judgment.get("label", "")).strip()
639
+ if label not in ATOM_SCORES:
640
+ label = "incorrect"
641
+ judgments_by_id[atom_id] = {
642
+ "id": atom_id,
643
+ "label": label,
644
+ "score": ATOM_SCORES[label],
645
+ "rationale": str(judgment.get("rationale", "")),
646
+ }
647
+
648
+ norm_judgments = []
649
+ weighted_score = 0.0
650
+ total_weight = 0.0
651
+ for atom in atoms:
652
+ atom_id = atom["id"]
653
+ weight = float(atom.get("weight", 1.0))
654
+ judgment = judgments_by_id.get(
655
+ atom_id,
656
+ {"id": atom_id, "label": "missing", "score": 0.0, "rationale": "No judgment returned."},
657
+ )
658
+ judgment["requirement"] = atom["requirement"]
659
+ judgment["weight"] = weight
660
+ norm_judgments.append(judgment)
661
+ weighted_score += judgment["score"] * weight
662
+ total_weight += weight
663
+
664
+ extras = parsed.get("unsupported_or_contradictory", []) if isinstance(parsed, dict) else []
665
+ if not isinstance(extras, list):
666
+ extras = []
667
+ material_extras = [
668
+ x for x in extras
669
+ if isinstance(x, dict) and str(x.get("severity", "")).strip() == "material"
670
+ ]
671
+ absence_mismatch = bool(parsed.get("absence_mismatch", False)) if isinstance(parsed, dict) else False
672
+
673
+ strict_label = (
674
+ bool(norm_judgments)
675
+ and all(j["label"] == "correct" for j in norm_judgments)
676
+ and not material_extras
677
+ and not absence_mismatch
678
+ )
679
+ partial_score = weighted_score / total_weight if total_weight > 0 else 0.0
680
+ if absence_mismatch:
681
+ partial_score = 0.0
682
+ elif material_extras and partial_score > 0.8:
683
+ partial_score = 0.8
684
+
685
+ return {
686
+ "strict_label": strict_label,
687
+ "partial_score": round(partial_score, 4),
688
+ "atom_judgments": norm_judgments,
689
+ "unsupported_or_contradictory": extras,
690
+ "absence_mismatch": absence_mismatch,
691
+ "overall_rationale": str(parsed.get("overall_rationale", "")) if isinstance(parsed, dict) else "",
692
+ }
693
+
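+ # Worked example of the aggregation above (hypothetical labels; assumes ATOM_SCORES maps
+ # "correct" -> 1.0 and "partially_correct" -> 0.5): two atoms of weight 1.0 judged
+ # "correct" and "partially_correct" give weighted_score = 1.5 and total_weight = 2.0, so
+ # partial_score = 0.75 and strict_label = False (not every atom is "correct").
+ # A "material" unsupported claim would further cap partial_score at 0.8, and an
+ # absence_mismatch forces it to 0.0.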
694
+
695
+ def judge_atomic(qdata, hypothesis, rubric, deployment_name, api_version, args):
696
+ prompt = build_atomic_eval_prompt(
697
+ qdata["question_type"],
698
+ qdata["question"],
699
+ qdata["answer"],
700
+ hypothesis,
701
+ rubric,
702
+ )
703
+ parsed, raw = call_json_llm(prompt, deployment_name, api_version, args)
704
+ scores = compute_atomic_scores(rubric, parsed)
705
+ return {
706
+ "model": args.eval_model_name,
707
+ "prompt_version": ATOMIC_PROMPT_VERSION,
708
+ "eval_mode": args.eval_mode,
709
+ "strict_label": scores["strict_label"],
710
+ "partial_score": scores["partial_score"],
711
+ "required_atoms": rubric["required_atoms"],
712
+ "atom_judgments": scores["atom_judgments"],
713
+ "unsupported_or_contradictory": scores["unsupported_or_contradictory"],
714
+ "absence_mismatch": scores["absence_mismatch"],
715
+ "overall_rationale": scores["overall_rationale"],
716
+ "raw_response": raw,
717
+ }
718
+
719
+
720
+ def should_skip_existing(existing_row, eval_mode):
721
+ if eval_mode == "legacy":
722
+ return "autoeval_label" in existing_row
723
+ atomic = existing_row.get("autoeval_atomic")
724
+ return bool(atomic and atomic.get("prompt_version") == ATOMIC_PROMPT_VERSION)
725
+
726
+
727
+ def safe_mean(values):
728
+ if not values:
729
+ return float("nan")
730
+ return float(np.mean(values))
731
+
732
+
733
+ def print_legacy_summary(logs, qtype2acc):
734
+ labels = [1 if x["autoeval_label"]["label"] else 0 for x in logs if "autoeval_label" in x]
735
+ print("Accuracy:", round(safe_mean(labels), 4))
736
+ for k, v in sorted(qtype2acc.items()):
737
+ print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
738
+
739
+
740
+ def print_atomic_summary(logs, qtype2strict, qtype2partial, eval_mode):
741
+ strict_values = [
742
+ 1 if x["autoeval_atomic"]["strict_label"] else 0
743
+ for x in logs
744
+ if "autoeval_atomic" in x
745
+ ]
746
+ partial_values = [
747
+ float(x["autoeval_atomic"]["partial_score"])
748
+ for x in logs
749
+ if "autoeval_atomic" in x
750
+ ]
751
+ if eval_mode in ("strict", "both"):
752
+ print("Strict Accuracy:", round(safe_mean(strict_values), 4))
753
+ for k, v in sorted(qtype2strict.items()):
754
+ print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
755
+ if eval_mode in ("partial", "both"):
756
+ print("Partial Score:", round(safe_mean(partial_values), 4))
757
+ for k, v in sorted(qtype2partial.items()):
758
+ print("\t{}: {} ({})".format(k, round(safe_mean(v), 4), len(v)))
759
+
760
+
761
+ def main():
762
+ parser = argparse.ArgumentParser()
763
+ parser.add_argument("--hyp_file", type=str, default=None)
764
+ parser.add_argument("--ref_file", type=str, required=True)
765
+ parser.add_argument("--eval_model_name", type=str, required=True)
766
+ parser.add_argument(
767
+ "--eval_mode",
768
+ type=str,
769
+ default="both",
770
+ choices=["legacy", "strict", "partial", "both"],
771
+ help="legacy uses the old yes/no judge; strict/partial/both use atomic JSON judging.",
772
+ )
773
+ parser.add_argument("--rubric_file", type=str, default=None)
774
+ parser.add_argument("--build_rubric_only", action="store_true", default=False)
775
+ parser.add_argument("--force_rebuild_rubric", action="store_true", default=False)
776
+ parser.add_argument("--result_file", type=str, default=None)
777
+ parser.add_argument("--debug", action="store_true", default=False)
778
+ parser.add_argument("--vllm", action="store_true", default=False)
779
+ parser.add_argument("--tritonai", action="store_true", default=False,
780
+ help="Use OpenAI-compatible LiteLLM proxy (set TRITONAI_API_KEY env var)")
781
+ parser.add_argument("--nvidia", action="store_true", default=False,
782
+ help="Use NVIDIA inference API (set NV_API_KEY env var)")
783
+ parser.add_argument("--verbose", action=argparse.BooleanOptionalAction, default=True)
784
+ args = parser.parse_args()
785
+
786
+ if not args.build_rubric_only and not args.hyp_file:
787
+ parser.error("--hyp_file is required unless --build_rubric_only is set")
788
+
789
+ metric_model = args.eval_model_name
790
+ deployment_name, api_version = model_zoo[metric_model]
791
+ references = read_json_or_jsonl(args.ref_file)
792
+ qid2qdata = {entry["question_id"]: entry for entry in references}
793
+ qid2qtype = {entry["question_id"]: entry["question_type"] for entry in references}
794
+ qtypes = set(qid2qtype.values())
795
+
796
+ rubric_data = None
797
+ rubric_file = args.rubric_file or default_rubric_file(args.ref_file)
798
+ if args.eval_mode != "legacy" or args.build_rubric_only:
799
+ rubric_data = load_rubric_file(rubric_file, args.ref_file)
800
+
801
+ if args.build_rubric_only:
802
+ for entry in tqdm(references, desc="building rubrics"):
803
+ get_or_build_rubric(entry, rubric_data, rubric_file, deployment_name, api_version, args)
804
+ print(f"Saved rubric file to {rubric_file}")
805
+ return
806
+
807
+ result_file = args.result_file or default_result_file(args.hyp_file, metric_model, args.eval_mode)
808
+ existing = read_existing_jsonl(result_file)
809
+ hypotheses = read_json_or_jsonl(args.hyp_file)
810
+
811
+ qtype2acc = {t: [] for t in qtypes}
812
+ qtype2strict = {t: [] for t in qtypes}
813
+ qtype2partial = {t: [] for t in qtypes}
814
+ logs = []
815
+
816
+ for entry in tqdm(hypotheses):
817
+ qid = entry.get("question_id")
818
+ if qid not in qid2qtype:
819
+ if qid is not None:
820
+ print(f"Warning: skipping {qid} as it is not in reference data.")
821
+ continue
822
+
823
+ if qid in existing and should_skip_existing(existing[qid], args.eval_mode):
824
+ existing_row = existing[qid]
825
+ logs.append(existing_row)
826
+ qtype = qid2qtype[qid]
827
+ if args.eval_mode == "legacy":
828
+ label = existing_row["autoeval_label"]["label"]
829
+ qtype2acc[qtype].append(1 if label else 0)
830
+ else:
831
+ atomic = existing_row["autoeval_atomic"]
832
+ qtype2strict[qtype].append(1 if atomic["strict_label"] else 0)
833
+ qtype2partial[qtype].append(float(atomic["partial_score"]))
834
+ continue
835
+
836
+ qdata = qid2qdata[qid]
837
+ qtype = qdata["question_type"]
838
+ hyp = entry["hypothesis"]
839
+
840
+ if args.eval_mode == "legacy":
841
+ prompt = get_anscheck_prompt(qtype, qdata["question"], qdata["answer"], hyp)
842
+ completion = llm_call(
843
+ deployment_name,
844
+ api_version,
845
+ prompt,
846
+ debug=args.debug,
847
+ vllm=args.vllm,
848
+ tritonai=args.tritonai,
849
+ nvidia=args.nvidia,
850
+ )
851
+ eval_response = completion.choices[0].message.content.strip()
852
+ last_line = next((l.strip().lower() for l in reversed(eval_response.splitlines()) if l.strip()), "")
853
+ label = last_line == "yes" or last_line.startswith("yes")
854
+ row = dict(entry)
855
+ row["autoeval_label"] = {
856
+ "model": metric_model,
857
+ "prompt_version": LEGACY_PROMPT_VERSION,
858
+ "label": label,
859
+ "raw_response": eval_response,
860
+ }
861
+ logs.append(row)
862
+ qtype2acc[qtype].append(1 if label else 0)
863
+ if args.verbose:
864
+ print(json.dumps({
865
+ "question": qdata["question"],
866
+ "answer": qdata["answer"],
867
+ "hypothesis": hyp,
868
+ "autoeval_label": label,
869
+ }, indent=4), flush=True)
870
+ append_jsonl(result_file, row)
871
+ continue
872
+
873
+ rubric = get_or_build_rubric(qdata, rubric_data, rubric_file, deployment_name, api_version, args)
874
+ try:
875
+ atomic_eval = judge_atomic(qdata, hyp, rubric, deployment_name, api_version, args)
876
+ except ValueError as _judge_err:
877
+ print(f"[WARN] judge_atomic failed for {qdata['question_id']}, writing zero score: {_judge_err}", flush=True)
878
+ atoms = rubric.get("required_atoms", [])
879
+ atomic_eval = {
880
+ "model": deployment_name,
881
+ "prompt_version": ATOMIC_PROMPT_VERSION,
882
+ "eval_mode": args.eval_mode,
883
+ "strict_label": False,
884
+ "partial_score": 0.0,
885
+ "required_atoms": atoms,
886
+ "atom_judgments": [{"id": a["id"], "label": "error", "score": 0.0, "rationale": "judge parse error", "requirement": a.get("requirement", ""), "weight": a.get("weight", 1.0)} for a in atoms],
887
+ "unsupported_or_contradictory": [],
888
+ "absence_mismatch": False,
889
+ "overall_rationale": f"Skipped: judge JSON parse error ({_judge_err})",
890
+ }
891
+ row = dict(entry)
892
+ row["autoeval_atomic"] = atomic_eval
893
+ logs.append(row)
894
+ qtype2strict[qtype].append(1 if atomic_eval["strict_label"] else 0)
895
+ qtype2partial[qtype].append(float(atomic_eval["partial_score"]))
896
+ if args.verbose:
897
+ print(json.dumps({
898
+ "question": qdata["question"],
899
+ "answer": qdata["answer"],
900
+ "hypothesis": hyp,
901
+ "strict_label": atomic_eval["strict_label"],
902
+ "partial_score": atomic_eval["partial_score"],
903
+ "atom_judgments": atomic_eval["atom_judgments"],
904
+ }, indent=4), flush=True)
905
+ append_jsonl(result_file, row)
906
+
907
+ if args.eval_mode == "legacy":
908
+ print_legacy_summary(logs, qtype2acc)
909
+ else:
910
+ print_atomic_summary(logs, qtype2strict, qtype2partial, args.eval_mode)
911
+ print(f"Rubric file: {rubric_file}")
912
+ print("Saved to", result_file)
913
+
914
+
915
+ if __name__ == "__main__":
916
+ main()
main.py ADDED
@@ -0,0 +1,1717 @@
1
+ import argparse
2
+ import os
3
+ import re
4
+ import json
5
+ import time
6
+ from json import JSONDecodeError
7
+ from datetime import datetime, timedelta
8
+ from typing import List, Dict, Any
9
+
10
+ from openai import OpenAI
11
+ try:
12
+ from openai import AzureOpenAI
13
+ from azure.identity import ChainedTokenCredential, AzureCliCredential, ManagedIdentityCredential, get_bearer_token_provider
14
+ # Azure scope for the OAuth bearer-token provider; override per deployment.
15
+ AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
16
+ if AZURE_OAUTH_SCOPE:
17
+ credential = get_bearer_token_provider(ChainedTokenCredential(
18
+ AzureCliCredential(),
19
+ ManagedIdentityCredential(),
20
+ ), AZURE_OAUTH_SCOPE)
21
+ else:
22
+ credential = None
23
+ except ImportError:
24
+ AzureOpenAI = None
25
+ credential = None
26
+
27
+ from model_zoo import model_zoo
28
+ from memory import EpisodicMemoryStore, SemanticMemoryStore
29
+ try:
30
+ import tiktoken
31
+ except ImportError:
32
+ tiktoken = None
33
+
34
+ try:
35
+ from transformers import AutoTokenizer, PreTrainedTokenizerBase
36
+ except ImportError:
37
+ AutoTokenizer = None
38
+ PreTrainedTokenizerBase = ()
39
+ from collections import defaultdict
40
+
41
+ def get_hf_tokenizer_for_vllm(model_name: str):
42
+ return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
43
+
44
+
45
+ # Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL).
46
+ endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
47
+
48
+ # OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy).
49
+ TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "")
50
+
51
+ # reading cached files
52
+ qid2plan = {}
53
+ if 'plan_cache' in os.environ:
54
+ plan_cache_file = os.environ['plan_cache']
55
+ if os.path.exists(plan_cache_file):
56
+ qid2plan = json.load(open(plan_cache_file))
57
+ else:
58
+ plan_cache_file = 'response_cache/qa/evolv_mem_v3_plan_cache_gpt5-1'
59
+
60
+ veri_reading_log_file = os.environ['reading_cache']
61
+ qid2rel_sess_ids = {}
62
+ if os.path.exists(veri_reading_log_file):
63
+ qid2rel_sess_ids = json.load(open(veri_reading_log_file))
64
+
65
+
66
+ # Cache file for retrieval results to avoid re-running expensive retrieval operations.
67
+ # Stores pre-computed search results for questions, including:
68
+ # - Question metadata (id, type, text, answer, dates)
69
+ # - Haystack information (session dates, content, IDs)
70
+ # - Retrieved results with query, ranked items, and evaluation metrics
71
+ # Format: JSONL file where each line contains a complete retrieval result for one question
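+ # Illustrative (hypothetical) shape of one cached line; the retrieval helpers below read
+ # 'question_id' and retrieval_results['ranked_items'][i]['corpus_id']:
+ # {"question_id": "q_001", "question_type": "multi-session", "question": "...",
+ #  "answer": "...", "haystack_session_ids": ["sess_12", "sess_47"],
+ #  "retrieval_results": {"query": "...",
+ #    "ranked_items": [{"corpus_id": "sess_47_turn_3", "score": 0.81},
+ #                     {"corpus_id": "sess_12", "score": 0.64}]}}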
72
+ retrieved_log_file = None
73
+ if 'ret_cache' in os.environ:
74
+ retrieved_log_file = os.environ['ret_cache']
75
+
76
+ print("loading existing retrieved results ...")
77
+ retrieved_data = [json.loads(line) for line in open(retrieved_log_file).readlines()]
78
+ retrieved_data_dict = {x['question_id']: x for x in retrieved_data}
79
+ valid_sess_set = set(json.load(open("dataset/all_sessions.json")).keys())
80
+
81
+
82
+ def parse_json(response_content):
83
+ """Safely parse JSON content from a string response."""
84
+ candidates = []
85
+
86
+ if '```json' in response_content:
87
+ start_idx = response_content.find('```json') + 7
88
+ end_idx = response_content.rfind('```')
89
+ if end_idx > start_idx:
90
+ candidates.append(response_content[start_idx:end_idx].strip())
91
+
92
+ # Also try brace-based extraction as fallback
93
+ brace_start = response_content.find('{')
94
+ brace_end = response_content.rfind('}') + 1
95
+ if brace_start >= 0 and brace_end > brace_start:
96
+ candidates.append(response_content[brace_start:brace_end].strip())
97
+
98
+ for json_block in candidates:
99
+ try:
100
+ result = json.loads(json_block)
101
+ return result
102
+ except (JSONDecodeError, ValueError):
103
+ continue
104
+
105
+ print(f"[Warning] Failed to decode JSON from response (all strategies failed)")
106
+ print(f"[Debug] Raw response content (truncated): {response_content[:500]}")
107
+ return {}
108
+
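+ # Intended behavior, illustrated on a hypothetical response: for
+ #   'Sure:\n```json\n{"keywords": ["marathon", "Boston"]}\n```'
+ # the fenced block is tried first and yields {"keywords": ["marathon", "Boston"]};
+ # if every candidate fails to parse, an empty dict is returned.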
109
+
110
+ def _retryable_status(e):
111
+ # Try common attributes first — return any extractable HTTP status code
112
+ status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
113
+ if status is not None:
114
+ return int(status)
115
+ resp = getattr(e, "response", None)
116
+ if resp is not None and getattr(resp, "status_code", None) is not None:
117
+ return int(resp.status_code)
118
+ # Fallback: infer from message text
119
+ msg = str(e).lower()
120
+ if "429" in msg or "rate limit" in msg:
121
+ return 429
122
+ if "500" in msg or "internal server error" in msg:
123
+ return 500
124
+     if "503" in msg or "api configuration unavailable" in msg:
125
+ return 503
126
+ return None
127
+
128
+
129
+ MAX_CONTEXT_TOKENS = 272_000
130
+ #MAX_CONTEXT_TOKENS = 256_000
131
+
132
+
133
+ class CharacterEncoder:
134
+ """Conservative fallback when tokenizer packages are unavailable."""
135
+
136
+ def encode(self, text, **kwargs):
137
+ return list(text)
138
+
139
+ def decode(self, toks):
140
+ return "".join(toks)
141
+
142
+
143
+ def _get_encoder(model_name: str):
144
+ """
145
+ Return a token encoder for the given model name.
146
+     Note: not cached; callers should reuse the returned encoder to avoid reloading HF tokenizers.
147
+ """
148
+ # Prefer explicit handling for Qwen models first
149
+ # Note: Adjust the path below if you have the model downloaded locally
150
+ if AutoTokenizer is not None and any(k in model_name for k in ["Qwen3", "Qwen/", "Qwen"]):
151
+ try:
152
+ return AutoTokenizer.from_pretrained(
153
+ "Qwen/Qwen3-30B-A3B-Instruct-2507",
154
+ trust_remote_code=True,
155
+ use_fast=False
156
+ )
157
+ except Exception as e:
158
+ print(f"[WARN] Failed to load Qwen tokenizer: {e}. Falling back to tiktoken.")
159
+
160
+ # For non-Qwen models, rely on tiktoken's mapping when possible
161
+ if tiktoken is not None:
162
+ try:
163
+ return tiktoken.encoding_for_model(model_name)
164
+ except Exception:
165
+ # Generic safe fallback
166
+ return tiktoken.get_encoding("cl100k_base")
167
+ return CharacterEncoder()
168
+
169
+
170
+ def _truncate_to_tokens(text, enc, max_tokens) -> str:
171
+ """
172
+ Truncates text to the last `max_tokens`.
173
+ Compatible with both tiktoken and Hugging Face AutoTokenizers.
174
+ """
175
+ # Handle Hugging Face Tokenizers
176
+ if PreTrainedTokenizerBase and isinstance(enc, PreTrainedTokenizerBase):
177
+ # add_special_tokens=False is crucial here to avoid double counting
178
+ # or inserting BOS/EOS in the middle of text during length checks
179
+ toks = enc.encode(text, add_special_tokens=False)
180
+ # Handle tiktoken
181
+ else:
182
+ toks = enc.encode(text, disallowed_special=())
183
+
184
+ if len(toks) <= max_tokens:
185
+ return text
186
+
187
+ # Keep the tail (usually the most relevant for instructions / recent context)
188
+ toks = toks[-max_tokens:]
189
+
190
+ return enc.decode(toks)
191
+
192
+
193
+ def truncate_chat_prompt(tokenizer, messages, max_context, max_output, overhead=256):
194
+ # Apply the model's chat template so token counting matches the server
195
+ prompt_text = tokenizer.apply_chat_template(
196
+ messages,
197
+ tokenize=False,
198
+ add_generation_prompt=True,
199
+ )
200
+ input_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids
201
+
202
+ budget = max_context - max_output - overhead
203
+ if budget < 0:
204
+ raise ValueError("max_output_tokens + overhead exceeds max_context_tokens")
205
+
206
+ if len(input_ids) > budget:
207
+ input_ids = input_ids[-budget:] # or keep the *start* depending on your needs
208
+ prompt_text = tokenizer.decode(input_ids, skip_special_tokens=False)
209
+
210
+ return prompt_text
211
+
212
+
213
+ def llm_call(deployment_name: str,
214
+ api_version: str,
215
+ _prompt: str,
216
+ max_context_tokens: int = MAX_CONTEXT_TOKENS,
217
+ max_output_tokens: int = 1024,
218
+ extra_overhead_tokens: int = 32,
219
+ debug: bool = False,
220
+ vllm: bool = False,
221
+ tritonai: bool = False,
222
+ nvidia: bool = False):
223
+ if nvidia:
224
+ client = OpenAI(
225
+ api_key=os.getenv("NV_API_KEY"),
226
+ base_url="https://inference-api.nvidia.com/v1",
227
+ )
228
+ elif tritonai:
229
+ client = OpenAI(
230
+ api_key=os.getenv("TRITONAI_API_KEY"),
231
+ base_url=TRITONAI_BASE_URL,
232
+ )
233
+ max_context_tokens = 131_072 # DeepSeek R1 128K context
234
+ # DeepSeek R1 uses thinking tokens before the answer; raise output budget
235
+ if max_output_tokens < 4096:
236
+ max_output_tokens = 4096
237
+ elif vllm:
238
+ # Use local vLLM OpenAI-compatible server
239
+ vllm_base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
240
+ vllm_api_key = os.getenv("VLLM_API_KEY", "EMPTY")
241
+ client = OpenAI(
242
+ base_url=vllm_base_url,
243
+ api_key=vllm_api_key,
244
+ )
245
+ # Override deployment_name with the vLLM-served model name when set
246
+ # (needed when main model is LiteLLM proxy but reading uses local vLLM)
247
+ deployment_name = os.getenv("VLLM_MODEL_NAME", deployment_name)
248
+ # vLLM Qwen3-30B-A3B-Instruct-2507 has 131,072-token context
249
+ max_context_tokens = 131_072
250
+ elif debug:
251
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
252
+ else:
253
+ client = AzureOpenAI(
254
+ azure_endpoint=endpoint,
255
+ azure_ad_token_provider=credential,
256
+ api_version=api_version,
257
+ )
258
+
259
+ enc = _get_encoder(deployment_name)
260
+
261
+ # How many tokens we can spend on the input
262
+ budget = max_context_tokens - max_output_tokens - extra_overhead_tokens
263
+ if budget < 0:
264
+ raise ValueError("max_output_tokens + overhead exceeds max_context_tokens")
265
+
266
+ prompt_truncated = _truncate_to_tokens(_prompt, enc, budget)
267
+
268
+ # Strip control characters and fix broken Unicode that break JSON serialization
269
+ if nvidia or tritonai:
270
+ prompt_truncated = prompt_truncated.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
271
+ prompt_truncated = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', prompt_truncated)
272
+ # Verify it's valid JSON-serializable
273
+ json.dumps(prompt_truncated)
274
+
275
+ # OpenAI-compatible proxy requires at least one user message; Azure accepts system role
276
+ msg_role = "user" if (tritonai or nvidia) else "system"
277
+ kwargs = {
278
+ 'model': deployment_name,
279
+ 'messages':[
280
+ {"role": msg_role, "content": prompt_truncated}
281
+ ]
282
+ }
283
+
284
+ while True:
285
+ try:
286
+ completion = client.chat.completions.create(**kwargs)
287
+ break
288
+ except Exception as e:
289
+ from openai import APITimeoutError as _APITimeoutError
290
+ if isinstance(e, _APITimeoutError):
291
+ print(f"[WARN] APITimeoutError from LLM; sleeping 30s then retrying...", flush=True)
292
+ time.sleep(30)
293
+ continue
294
+ st = _retryable_status(e)
295
+ # 404 from LiteLLM proxy/Bedrock is intermittent (model temporarily unavailable)
296
+ retryable = (429, 500, 503, 403) + ((404,) if tritonai else ())
297
+ if st in retryable:
298
+ print(
299
+                     f"[WARN] HTTP {st} from LLM; sleeping 60s then retrying...",
300
+ flush=True
301
+ )
302
+ time.sleep(60)
303
+ continue
304
+ # Non-retryable -> re-raise
305
+ print('One exception captured', repr(e), flush=True)
306
+ raise
307
+
308
+ #answer = (completion.choices[0].message.content or "").strip()
309
+ #return answer
310
+ return completion
311
+
312
+ def custom_to_iso8601(time_str):
313
+ """
314
+ Convert '2023/04/10 (Mon) 23:07' to '2023-04-10T23:07:00'
315
+ """
316
+ # Remove the weekday (e.g., "(Mon)")
317
+ clean = time_str.split('(')[0].strip() + ' ' + time_str.split(')')[-1].strip()
318
+ # Parse the cleaned string
319
+ dt = datetime.strptime(clean, "%Y/%m/%d %H:%M")
320
+ # Format as ISO 8601
321
+ return dt.isoformat()
322
+
323
+
324
+ def evaluate_retrieval(recalled_docs, correct_docs, k=10):
325
+ #recalled_docs = set(corpus_ids[idx] for idx in rankings[:k])
326
+ recall_any = float(any(doc in recalled_docs for doc in correct_docs))
327
+ recall_all = float(all(doc in recalled_docs for doc in correct_docs))
328
+ return recall_any, recall_all
329
+
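+ # Minimal usage sketch: with recalled_docs = {"sess_1", "sess_3"} and
+ # correct_docs = ["sess_3", "sess_7"], recall_any is 1.0 (sess_3 was recalled)
+ # while recall_all is 0.0 (sess_7 was not).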
330
+
331
+ def print_average_metrics(retrieval_metric_list):
332
+ metric_sums = defaultdict(float)
333
+ metric_counts = defaultdict(int)
334
+
335
+ for metric in retrieval_metric_list:
336
+ for k, v in metric.items():
337
+ metric_sums[k] += v
338
+ metric_counts[k] += 1
339
+
340
+ print("\t\t Average metrics:")
341
+ for k in sorted(metric_sums):
342
+ avg = metric_sums[k] / metric_counts[k]
343
+ print(f"\t\t {k}: {avg:.4f}")
344
+
345
+
346
+ # Load prompt template
347
+ prompt_path = "prompts/agentic_retrieval_prompt.txt"
348
+ with open(prompt_path, "r", encoding="utf-8") as f:
349
+ stg_prompt = f.read()
350
+
351
+
352
+ class ChatHistory:
353
+ def __init__(self, data: Dict[str, Any] = None, sessions: List = None):
354
+ assert not (data is not None and sessions is not None), "ChatHistory: Only one of data or sessions may be provided."
355
+
356
+ if data is not None: # From raw data dict
357
+ self.raw_data = data
358
+ self.sessions = []
359
+ self._parse_sessions()
360
+ elif sessions is not None: # From provided sessions list
361
+ self.sessions = sessions
362
+ self.messages = []
363
+ for sess in self.sessions:
364
+ session_id = sess['session_id']
365
+ timestamp = sess['timestamp']
366
+ for turn_idx, msg in enumerate(sess['session']):
367
+ entry = {
368
+ "role": msg.get("role"),
369
+ "content": msg.get("content"),
370
+ "session_id": session_id,
371
+ "turn_index": turn_idx,
372
+ "timestamp": timestamp,
373
+ "iso_datetime": timestamp.isoformat(),
374
+ "has_answer": msg.get("has_answer", False)
375
+ }
376
+ self.messages.append(entry)
377
+ else:
378
+ self.sessions = []
379
+ self.messages = []
380
+
381
+ def get_contents(self, granularity='turn', _format='json') -> list:
382
+ if granularity == "turn":
383
+ if _format == "json":
384
+ return [json.dumps(msg) for msg in self.messages]
385
+ else:
386
+ return [msg['content'] for msg in self.messages]
387
+ else: # granularity == 'session'
388
+ if _format == "json":
389
+                 # default=str handles the datetime 'timestamp' field, which json cannot serialize
+                 return [json.dumps(session, default=str) for session in self.sessions]
390
+ else:
391
+                 return [json.dumps([{"role": m["role"], "content": m["content"]} for m in session["session"]]) for session in self.sessions]
393
+
394
+ def to_prompt(self, granularity='session', _format="json"):
395
+ history_str = ""
396
+ for session in self.sessions:
397
+ sess_str = json.dumps([{"role": x["role"], "content": x["content"]} for x in session['session']])
398
+ history_str += f"Session Date: {session['session_date']}\nSession Content:\n{sess_str}\n"
399
+ return history_str
400
+
401
+ def get_session_ids(self):
402
+ return [s['session_id'] for s in self.sessions]
403
+
404
+ @staticmethod
405
+ def _parse_date(date_str: str) -> datetime:
406
+ # Convert '2023/04/10 (Mon) 17:50' to datetime
407
+ # Remove weekday in parentheses
408
+ date_part, time_part = date_str.split('(')[0].strip(), date_str.split(')')[-1].strip()
409
+ dt = datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M")
410
+ return dt
411
+
412
+ def _parse_sessions(self):
413
+ """
414
+ Flattens sessions into a list of messages, each with ISO 8601 date, session ID, and turn index
415
+ """
416
+ self.messages = []
417
+ for date_str, session_id, session, topic in zip(
418
+ self.raw_data['haystack_dates'],
419
+ self.raw_data['haystack_session_ids'],
420
+ self.raw_data['haystack_sessions'],
421
+ self.raw_data['haystack_topics']
422
+ ):
423
+ timestamp = self._parse_date(date_str)
424
+ for turn_idx, msg in enumerate(session):
425
+ entry = {
426
+ "role": msg.get("role"),
427
+ "content": msg.get("content"),
428
+ "session_id": session_id,
429
+ "turn_index": turn_idx,
430
+ "timestamp": timestamp,
431
+ "iso_datetime": timestamp.isoformat(),
432
+ "session_date": date_str,
433
+ "has_answer": msg.get("has_answer", False)
434
+ }
435
+ self.messages.append(entry)
436
+ self.sessions.append({
437
+ "session_date": date_str,
438
+ "timestamp": timestamp,
439
+ "session_id": session_id,
440
+ "session": session,
441
+ "topic": topic,
442
+ })
443
+ # Optionally, sort by time (ascending)
444
+ #self.sessions.sort(key=lambda x: x["timestamp"])
445
+ #self.messages.sort(key=lambda x: x["timestamp"])
446
+
447
+ def __len__(self):
448
+ return len(self.sessions)
449
+
450
+ def __getitem__(self, idx) -> Dict[str, any]:
451
+ return self.sessions[idx] # Return session dict
452
+
453
+ def get_item_by_index(self, idx):
454
+ if isinstance(idx, range) or isinstance(idx, list):
455
+ max_idx = len(self.sessions)
456
+ valid_indices = [i for i in idx if 0 <= i < max_idx]
457
+ selected_sessions = [self.sessions[i] for i in valid_indices]
458
+ return ChatHistory(sessions=selected_sessions)
459
+ else:
460
+ raise ValueError("Input must be a list or range of indices.")
461
+
462
+ def get_item_by_session_ids(self, sess_set):
463
+ if not isinstance(sess_set, set):
464
+ sess_set = set(sess_set)
465
+ new_sessions = []
466
+ for sess in self.sessions:
467
+ if sess['session_id'] in sess_set:
468
+ new_sessions.append(sess)
469
+
470
+ return ChatHistory(sessions=new_sessions)
471
+
472
+ def get_item_by_ranked_session(self, sess_id_sorted):
473
+ new_sessions = []
474
+ for sess_id in sess_id_sorted:
475
+ for sess in self.sessions:
476
+ if sess['session_id'] in sess_id:
477
+ new_sessions.append(sess)
478
+
479
+ return ChatHistory(sessions=new_sessions)
480
+
481
+ def get_item_by_topics(self, topics):
482
+ new_sessions = []
483
+ new_sess_ids = set()
484
+ for sess in self.sessions:
485
+ for tp in sess['topic']:
486
+ if tp in topics and sess['session_id'] not in new_sess_ids:
487
+ new_sessions.append(sess)
488
+ new_sess_ids.add(sess['session_id'])
489
+ break
490
+
491
+ return ChatHistory(sessions=new_sessions)
492
+
493
+ def merge_rel_sess(self, new_sessions):
494
+ # Gather all current and new sessions in a dict keyed by session_id
495
+ all_sessions = {s["session_id"]: s for s in self.sessions}
496
+
497
+ # add if new
498
+ for s in new_sessions.sessions:
499
+ if s["session_id"] not in all_sessions:
500
+ all_sessions[s["session_id"]] = s
501
+
502
+ # Reconstruct raw_data for new ChatHistory
503
+ merged_raw_data = {
504
+ "haystack_dates": [s["session_date"] for k, s in all_sessions.items()],
505
+ "haystack_session_ids": [s["session_id"] for k, s in all_sessions.items()],
506
+ "haystack_sessions": [s["session"] for k, s in all_sessions.items()],
507
+ "haystack_topics": [s["topic"] for k, s in all_sessions.items()],
508
+ }
509
+         merged = ChatHistory(data=merged_raw_data)
+         self.sessions = merged.sessions
+         self.messages = merged.messages
510
+
511
+
512
+
513
+ def generate_keywords(question: str, deployment_name, api_version, debug=False, vllm=None,
514
+ tritonai=False, nvidia=False) -> List[str]:
515
+ # read prompt from `keyword_search_prompt.txt` file
516
+ with open('prompts/keyword_search_prompt.txt') as f:
517
+ prompt_template = f.read()
518
+
519
+ prompt = prompt_template + question
520
+ # Call the LLM to generate keywords
521
+ completion = llm_call(
522
+ deployment_name,
523
+ api_version,
524
+ prompt,
525
+ debug=debug,
526
+ vllm=vllm,
527
+ tritonai=tritonai,
528
+ nvidia=nvidia,
529
+ )
530
+
531
+ response_content = (completion.choices[0].message.content or "").strip()
532
+ result = parse_json(response_content)
533
+ keywords = result["keywords"] if "keywords" in result else []
534
+ return keywords
535
+
536
+ def keyword_search(chat_history: ChatHistory, keywords: list):
537
+ print(f"\t\t** keyword search **: {keywords}")
538
+ # Gather all messages that match
539
+ start_time = time.time()
540
+ matched_msgs = [
541
+ msg for msg in chat_history.messages
542
+ if any(kw.lower() in (msg.get("content") or "").lower() for kw in keywords)
543
+ ]
544
+ end_time = time.time()
545
+ execution_time = end_time - start_time
546
+
547
+ if matched_msgs:
548
+ new_sess_ids = set()
549
+ for msg in matched_msgs:
550
+ key = msg["session_id"]
551
+ new_sess_ids.add(key)
552
+ new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids)
553
+ else:
554
+ new_chat_history = ChatHistory()
555
+
556
+ return new_chat_history
557
+
558
+
559
+ def is_turn_id(text):
560
+ pattern = r'_\d+$'
561
+ return bool(re.search(pattern, text))
562
+
563
+
564
+ def embedding_search(chat_history: ChatHistory, qid: str, top_k: int = 50, exclude_sess=None):
565
+ print("\t\t** embedding based retrieval **")
566
+ new_sess_ids = []
567
+
568
+ curr_all_sess = set(chat_history.get_session_ids())
569
+ if exclude_sess:
570
+ curr_all_sess = curr_all_sess - set(exclude_sess)
571
+
572
+ for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
573
+ if item["corpus_id"] in valid_sess_set:
574
+ sid = item["corpus_id"]
575
+ else:
576
+ tokens = item["corpus_id"].split("_")
577
+
578
+ if "_turn" in item["corpus_id"]:
579
+ sid = item["corpus_id"].split("_turn")[0]
580
+ elif "_fact" in item["corpus_id"]:
581
+ sid = item["corpus_id"].split("_fact")[0]
582
+ elif "noans" in item["corpus_id"]:
583
+ sid = item["corpus_id"].replace("noans", "answer")
584
+ elif is_turn_id(item["corpus_id"]):
585
+ sid = "_".join(tokens[:-1]) # remove turn index
586
+ else:
587
+ sid = item["corpus_id"]
588
+
589
+ if sid not in valid_sess_set:
590
+ print(item["corpus_id"], sid)
591
+
592
+ assert sid in valid_sess_set
593
+
594
+ if sid in curr_all_sess:
595
+ new_sess_ids.append(sid)
596
+ if len(new_sess_ids) == top_k:
597
+ break
598
+ new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids)
599
+ return new_chat_history
600
+
601
+
602
+ def filter_out_by_embedding(chat_history: ChatHistory, qid: str, top_k: int = 50):
603
+ print("\t\t** [filter_out] embedding based retrieval - loading existing results ...")
604
+ new_sess_ids = []
605
+ curr_all_sess = set(chat_history.get_session_ids())
606
+ for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
607
+ if item["corpus_id"] in valid_sess_set:
608
+ sid = item["corpus_id"]
609
+ else:
610
+ tokens = item["corpus_id"].split("_")
611
+
612
+ if "_turn" in item["corpus_id"]:
613
+ sid = item["corpus_id"].split("_turn")[0]
614
+ elif "_fact" in item["corpus_id"]:
615
+ sid = item["corpus_id"].split("_fact")[0]
616
+ elif "noans" in item["corpus_id"]:
617
+ sid = item["corpus_id"].replace("noans", "answer")
618
+ elif is_turn_id(item["corpus_id"]):
619
+ sid = "_".join(tokens[:-1]) # remove turn index
620
+ else:
621
+ sid = item["corpus_id"]
622
+
623
+ if sid not in valid_sess_set:
624
+ print(item["corpus_id"], sid)
625
+
626
+ assert sid in valid_sess_set
627
+
628
+ if sid in curr_all_sess:
629
+ new_sess_ids.append(sid)
630
+ if len(new_sess_ids) == top_k:
631
+ break
632
+ new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids)
633
+ return new_chat_history, 0.0
634
+
635
+
636
+ def flat_embedding_top_k_ids(qid: str, haystack_sess_ids: List[str], top_k: int) -> List[str]:
637
+ """
638
+ Pull the top_k session IDs from the global GTE retrieval cache (retrieved_data_dict),
639
+ constrained to the question's haystack. Mirrors embedding_search() but operates on
640
+ IDs only (no ChatHistory). Used by hier_union to widen the Stage-2 pool.
641
+ """
642
+ haystack_set = set(haystack_sess_ids)
643
+ ids: List[str] = []
644
+ for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
645
+ cid = item["corpus_id"]
646
+ if cid in valid_sess_set:
647
+ sid = cid
648
+ elif "_turn" in cid:
649
+ sid = cid.split("_turn")[0]
650
+ elif "_fact" in cid:
651
+ sid = cid.split("_fact")[0]
652
+ elif "noans" in cid:
653
+ sid = cid.replace("noans", "answer")
654
+ elif is_turn_id(cid):
655
+ sid = "_".join(cid.split("_")[:-1])
656
+ else:
657
+ sid = cid
658
+ if sid in haystack_set and sid not in ids:
659
+ ids.append(sid)
660
+ if len(ids) == top_k:
661
+ break
662
+ return ids
663
+
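+ # Illustrative ID mapping performed above (session IDs are hypothetical): a cached
+ # corpus_id of "sess_12_turn_4" or "sess_12_fact_2" maps back to session "sess_12",
+ # a bare trailing "_3" turn suffix is stripped, and a "noans" marker is rewritten to
+ # "answer", so each ID matches the session-level haystack before the top_k cutoff.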
664
+
665
+ def semantic_embedding_search(
666
+ qid: str,
667
+ haystack_sess_ids: List[str],
668
+ semantic_retrieved_dict: dict,
669
+ top_k: int = 50,
670
+ ) -> List[str]:
671
+ """
672
+ Like embedding_search() but reads from the pre-computed semantic-gte retrieval cache.
673
+ Returns an ordered list of up to top_k session IDs from the haystack.
674
+ """
675
+ print("\t\t** semantic embedding retrieval **")
676
+ haystack_set = set(haystack_sess_ids)
677
+ ranked_ids: List[str] = []
678
+ for item in semantic_retrieved_dict[qid]["retrieval_results"]["ranked_items"]:
679
+ sid = item["corpus_id"] # already session-level (no turn suffix)
680
+ if sid in haystack_set and sid not in ranked_ids:
681
+ ranked_ids.append(sid)
682
+ if len(ranked_ids) == top_k:
683
+ break
684
+ return ranked_ids
685
+
686
+
687
+ def time_filter(chat_history: ChatHistory, start_date: str, end_date: str) -> ChatHistory:
688
+ # Returns all messages with timestamp in the ISO date range [start_date, end_date] (inclusive).
689
+ start_time = time.time()
690
+ try:
691
+ start = datetime.fromisoformat(start_date)
692
+ end = datetime.fromisoformat(end_date)
693
+ filtered_msgs = [msg for msg in chat_history.messages if start.date() <= msg["timestamp"].date() <= end.date()]
694
+ except Exception as e:
695
+ print("Converting date error: ", e)
696
+ filtered_msgs = []
697
+ end_time = time.time()
698
+ execution_time = end_time - start_time
699
+
700
+ if filtered_msgs:
701
+ new_sess_ids = set()
702
+ for msg in filtered_msgs:
703
+ key = msg["session_id"]
704
+ new_sess_ids.add(key)
705
+ new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids)
706
+ else:
707
+ new_chat_history = ChatHistory()
708
+
709
+ return new_chat_history
710
+
711
+
712
+ class RetrievalAgent:
713
+ def __init__(
714
+ self,
715
+ history: List[Dict],
716
+ topics: List[str],
717
+ user_profile: str = None,
718
+ debug: bool = False,
719
+ vllm: bool = False,
720
+ vllm_reading: bool = False,
721
+ tritonai: bool = False,
722
+ nvidia: bool = False,
723
+ n_chunks: int = 10,
724
+ topic_filter: bool = True,
725
+ no_time_filter: bool = False,
726
+ semantic_store: SemanticMemoryStore = None,
727
+ episodic_store: EpisodicMemoryStore = None,
728
+ hier_v2: bool = False,
729
+ hier_union: bool = False,
730
+ hier_union_flat_k: int = 20,
731
+ no_early_answer: bool = False,
732
+ ):
733
+ self.chat_history = history
734
+ self.user_profile = user_profile
735
+ self.topics = topics
736
+ self.rel_sess = ChatHistory()
737
+ self.evidence = []
738
+ self.debug = debug
739
+ self.vllm = vllm
740
+ self.vllm_reading = vllm_reading # use vLLM only for _read_and_verify
741
+ self.tritonai = tritonai # use LiteLLM proxy for non-reading LLM calls
742
+ self.nvidia = nvidia # use NVIDIA inference API
743
+ self.no_time_filter = no_time_filter # skip time_filter steps in strategy
744
+ self.n_chunks = n_chunks
745
+ self.topic_filter = topic_filter
746
+ self.semantic_store = semantic_store
747
+ self.episodic_store = episodic_store
748
+ self.hier_v2 = hier_v2
749
+ self.hier_union = hier_union
750
+ self.hier_union_flat_k = hier_union_flat_k
751
+ self.no_early_answer = no_early_answer
752
+ self.token_budget = {
753
+ 'planning': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
754
+ 'verification_reading': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
755
+ 'is_answerable': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
756
+ 'final_answer': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0},
757
+ }
758
+
759
+ if self.user_profile:
760
+ with open('prompts/read_and_extract_prompt.txt') as f:
761
+ self.read_prompt_template = f.read()
762
+ else: # ablation: wo_profile
763
+ with open('prompts/agentic_retrieval_prompt_wo_profile.txt') as f:
764
+ self.read_prompt_template = f.read()
765
+
766
+ def _track_usage(self, component: str, completion) -> None:
767
+ """Accumulate prompt/completion token counts for a named component."""
768
+ usage = getattr(completion, 'usage', None)
769
+ if usage is None:
770
+ return
771
+ self.token_budget[component]['prompt_tokens'] += getattr(usage, 'prompt_tokens', 0) or 0
772
+ self.token_budget[component]['completion_tokens'] += getattr(usage, 'completion_tokens', 0) or 0
773
+ self.token_budget[component]['n_calls'] += 1
774
+
775
+ def get_token_budget(self) -> dict:
776
+ """Return token_budget with an added 'total' entry."""
777
+ total = {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}
778
+ for v in self.token_budget.values():
779
+ total['prompt_tokens'] += v['prompt_tokens']
780
+ total['completion_tokens'] += v['completion_tokens']
781
+ total['n_calls'] += v['n_calls']
782
+ return {**self.token_budget, 'total': total}
783
+
784
+ def is_answerable(self, question: str, question_date: str, retrieved_sess, evidence, model_info, context_str: str = None) -> bool:
785
+ context = ""
786
+ for k, v in evidence.items(): # key: "profile", "tags", "chat_clues"
787
+ for e in v:
788
+ context += f"{e}\n"
789
+
790
+ # context_str overrides retrieved_sess.to_prompt() (used in Stage 1 with semantic context)
791
+ sess_str = context_str if context_str is not None else retrieved_sess.to_prompt()
792
+
793
+ # Include user profile in the prompt when available
794
+ profile_section = ""
795
+ if self.user_profile:
796
+ profile_section = f"\nUser Profile:\n{self.user_profile}\n"
797
+
798
+ ia_prompt_prefix = f"""
799
+ You are a decision-making agent tasked with determining when sufficient information has been gathered to answer a user's question.
800
+
801
+ Your Task:
802
+ Analyze the provided question, current date, available memory context, and available evidence to make a binary decision: Answerable or Not answerable. If the information is not sufficient, explain what specific information is needed to provide hints for the next retrieval stage.
803
+
804
+ Question: {question}
805
+ Current Date: {question_date}
806
+ {profile_section}
807
+ Memory Context:
808
+ """
809
+
810
+ output_str = """
811
+ Output (always JSON — choose fields per the rules above)
812
+
813
+ Case 1 — Answerable:
814
+ {
815
+ "is_answerable": true,
816
+ "answer": "<concise answer grounded strictly in the Evidence>"
817
+ }
818
+
819
+ Case 2 — Not answerable:
820
+ {
821
+ "is_answerable": false,
822
+ "info_needed": ["<specific missing detail 1>", "<specific missing detail 2>"]
823
+ }
824
+ """
825
+ deployment_name, api_version = model_info
826
+
827
+ # ------------------------------------------------------------------
828
+ # Token-based truncation: keep the *end* of sess_str under budget
829
+ # ------------------------------------------------------------------
830
+ enc = _get_encoder(deployment_name)
831
+
832
+ # Model/context limits
833
+ if self.vllm or self.tritonai:
834
+ # Large context for vLLM / LiteLLM proxy models
835
+ model_max_ctx = 131_072
836
+ else:
837
+ model_max_ctx = MAX_CONTEXT_TOKENS
838
+
839
+ max_output_tokens = 1024
840
+ extra_overhead_tokens = 32
841
+
842
+ # Total budget available for input tokens
843
+ budget = model_max_ctx - max_output_tokens - extra_overhead_tokens
844
+ if budget <= 0:
845
+ raise ValueError(
846
+ f"max_output_tokens ({max_output_tokens}) + overhead "
847
+ f"({extra_overhead_tokens}) exceeds model_max_ctx ({model_max_ctx})"
848
+ )
849
+
850
+ # Token lengths of static pieces
851
+ prefix_tokens = enc.encode(ia_prompt_prefix, disallowed_special=())
852
+ output_tokens = enc.encode(output_str, disallowed_special=())
853
+ sess_tokens = enc.encode(sess_str, disallowed_special=())
854
+
855
+ # Budget for sess_str tokens
856
+ available_for_sess = budget - len(prefix_tokens) - len(output_tokens)
857
+
858
+ if available_for_sess <= 0:
859
+ # No room for history at all; drop it
860
+ truncated_sess_str = ""
861
+ else:
862
+ if len(sess_tokens) > available_for_sess:
863
+ # Keep the *last* available_for_sess tokens (drop oldest history)
864
+ truncated_sess_tokens = sess_tokens[-available_for_sess:]
865
+ truncated_sess_str = enc.decode(truncated_sess_tokens)
866
+ else:
867
+ truncated_sess_str = sess_str
868
+
869
+ ia_prompt = ia_prompt_prefix + truncated_sess_str + "\n"
870
+
871
+ # ------------------------------------------------------------------
872
+ # Call the LLM with the already-truncated prompt
873
+ # ------------------------------------------------------------------
874
+
875
+ completion = llm_call(
876
+ deployment_name,
877
+ api_version,
878
+ ia_prompt + output_str,
879
+ max_context_tokens=model_max_ctx, # matches what we used for budgeting
880
+ max_output_tokens=max_output_tokens,
881
+ extra_overhead_tokens=extra_overhead_tokens,
882
+ debug=self.debug,
883
+ vllm=self.vllm,
884
+ tritonai=self.tritonai,
885
+ nvidia=self.nvidia,
886
+ )
887
+ self._track_usage('is_answerable', completion)
888
+ response_content = (completion.choices[0].message.content or "").strip()
889
+ print(f"\t\t[Agent] is_answerable: {response_content}")
890
+ result = parse_json(response_content)
891
+ if not result:
892
+ print("[Warning] Empty or invalid JSON in is_answerable() response.")
893
+ return {"is_answerable": False, "info_needed": ["Parsing failed"]}
894
+ return result
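Aside on the truncation logic above: the prompt keeps the tail of the serialized session history so the newest turns survive when the token budget is exceeded. Below is a minimal, self-contained sketch of that strategy, assuming a tiktoken-style encoder; the helper name `truncate_tail` is illustrative and not part of the released code.

```python
import tiktoken

def truncate_tail(text: str, budget: int, model: str = "gpt-4o") -> str:
    """Keep only the last `budget` tokens of `text`, dropping the oldest history."""
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown deployment names fall back to a generic encoder.
        enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text, disallowed_special=())
    if len(tokens) <= budget:
        return text
    return enc.decode(tokens[-budget:])

# Budget mirrors the arithmetic above: context window minus output and overhead.
budget = 131_072 - 1024 - 32
shortened = truncate_tail("user: hello\nassistant: hi\n" * 50_000, budget)
```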
895
+
896
+ def _read_and_verify(self, question: str, question_date: str, evidence: ChatHistory, n_chunks=10) -> "tuple[ChatHistory, list]":
897
+ # Read the evidence in chunks and select only the sessions needed to answer the question
898
+ relevant_indices = []
899
+ evidence_list = []
900
+ max_idx = len(evidence)
901
+ for j in range(0, len(evidence), n_chunks):
902
+ chunk_range = range(j, j+n_chunks)
903
+ valid_indices = [i for i in chunk_range if 0 <= i < max_idx]
904
+ cur_chunk = evidence.get_item_by_index(valid_indices)
905
+ cur_chunk_sess = [[{"role": m["role"], "content": m["content"]} for m in sess['session']]
906
+ for sess in cur_chunk.sessions]
907
+ cur_chunk_sess_date = [sess['session_date'] for sess in cur_chunk.sessions]
908
+ sess_input_str = "\n".join([
909
+ f"### Session Index: {i}\n### Session Date: {sess_date}\n\n{json.dumps(sess)}\n"
910
+ for i, (sess, sess_date) in enumerate(zip(cur_chunk_sess, cur_chunk_sess_date))
911
+ ])
912
+ _prompt = self.read_prompt_template + f"## Question: {question}\n## Question Date: {question_date}\n## Session list:\n\n{sess_input_str}\nNow, identify **only the sessions strictly necessary to answer the question**."
913
+ completion = llm_call(
914
+ deployment_name,
915
+ api_version,
916
+ _prompt,
917
+ debug=self.debug,
918
+ vllm=self.vllm or self.vllm_reading,
919
+ nvidia=self.nvidia,
920
+ )
921
+ self._track_usage('verification_reading', completion)
922
+ response_content = (completion.choices[0].message.content or "").strip()
923
+
924
+ print(f"\t\t {valid_indices[0]}~{valid_indices[-1]}: response: {response_content.replace(chr(10), '')}")
925
+ try:
926
+ start_idx = response_content.rfind('{')
927
+ end_idx = response_content.rfind('}') + 1
928
+ json_block = response_content[start_idx:end_idx]
929
+ result = json.loads(json_block)
930
+
931
+ if "index" in result and result['index'] and 'evidence' and result:
932
+ relevant_indices.extend([j + idx for idx in result['index']])
933
+ evidence_list.extend(result['evidence'])
934
+ except Exception as e:
935
+ print(f"Error parsing LLM response: {e}")
936
+
937
+ if relevant_indices:
938
+ return evidence.get_item_by_index(relevant_indices), evidence_list
939
+ else:
940
+ return ChatHistory(), []
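The chunk loop above asks the reader LLM for chunk-local indices and then shifts them by the window start, which is what `j + idx` does. A stand-alone sketch of that index bookkeeping, with `select_relevant` standing in for the LLM call:

```python
def chunked_select(items, n_chunks, select_relevant):
    """Walk `items` in windows of `n_chunks`; `select_relevant` returns indices
    local to the window, which are offset by the window start `j` to recover
    global positions (mirroring the `j + idx` bookkeeping above)."""
    global_indices = []
    for j in range(0, len(items), n_chunks):
        chunk = items[j:j + n_chunks]
        for idx in select_relevant(chunk):   # chunk-local indices
            global_indices.append(j + idx)   # map back to global positions
    return global_indices

# Toy check: select every "x"
items = ["a", "x", "b", "x", "c", "x"]
assert chunked_select(items, 2, lambda ch: [i for i, v in enumerate(ch) if v == "x"]) == [1, 3, 5]
```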
941
+
942
+ def _read_and_verify_with_cache(self, qid:str, pool):
943
+ relevant_sess_ids = []
944
+
945
+ for sess in pool.sessions:
946
+ sess_id = sess['session_id']
947
+ if sess_id in qid2rel_sess_ids[qid]:
948
+ relevant_sess_ids.append(sess_id)
949
+
950
+ if len(relevant_sess_ids) > 0:
951
+ return pool.get_item_by_session_ids(relevant_sess_ids), []
952
+ else:
953
+ return ChatHistory(), []
954
+
955
+ def _plan(self, query: str, query_date: str, attempt_record: list, model_info) -> str:
956
+ if self.user_profile:
957
+ template = """
958
+ ### User profile: {user_profile}
959
+ ### Chat history topics: {topics}
960
+ ### User query: {query}
961
+ ### User query date: {query_date}
962
+ ### Previous attempts:
963
+ {strategies_info}
964
+ """
965
+ else:
966
+ template = """
967
+ ### Chat history topics: {topics}
968
+ ### User query: {query}
969
+ ### User query date: {query_date}
970
+ ### Previous attempts:
971
+ {strategies_info}
972
+ """
973
+
974
+ if attempt_record:
975
+ strategies_info = ""
976
+ for loop_num, entry in enumerate(attempt_record):
977
+ strategies_info += f"\nloop_iteration: {loop_num+1}\n"
978
+ strategies_info += "\n".join(entry.get('step_logs', []))
979
+ if 'n_retrieved_sess' in entry and 'evidence' in entry:
980
+ strategies_info += f"Retrieved {entry['n_retrieved_sess']} docs, observed_evidence: {entry['evidence']}"
981
+ if entry['n_retrieved_sess'] == 0:
982
+ strategies_info += f"Additional Instruction: Re-try without filter methods if the previous paln includes topics or time-filtering\n"
983
+ else:
984
+ strategies_info = "(No previous attempt exists)"
985
+
986
+ if self.user_profile:
987
+ prompt_filled = stg_prompt + template.format(
988
+ user_profile=self.user_profile,
989
+ topics=",".join(self.topics),
990
+ query=query,
991
+ query_date=query_date,
992
+ strategies_info=strategies_info)
993
+ else:
994
+ prompt_filled = stg_prompt + template.format(
995
+ topics=",".join(self.topics),
996
+ query=query,
997
+ query_date=query_date,
998
+ strategies_info=strategies_info)
999
+
1000
+ deployment_name, api_version = model_info
1001
+ completion = llm_call(
1002
+ deployment_name,
1003
+ api_version,
1004
+ prompt_filled,
1005
+ debug=self.debug,
1006
+ vllm=self.vllm,
1007
+ tritonai=self.tritonai,
1008
+ nvidia=self.nvidia,
1009
+ )
1010
+ self._track_usage('planning', completion)
1011
+ response_content = (completion.choices[0].message.content or "").strip()
1012
+ _plan = parse_json(response_content)
1013
+ if not _plan:
1014
+ print("[Warning] Failed to parse plan JSON — retrying once.")
1015
+ completion = llm_call(
1016
+ deployment_name,
1017
+ api_version,
1018
+ prompt_filled,
1019
+ debug=self.debug,
1020
+ vllm=self.vllm,
1021
+ tritonai=self.tritonai,
1022
+ nvidia=self.nvidia,
1023
+ )
1024
+ self._track_usage('planning', completion)
1025
+ response_content = (completion.choices[0].message.content or "").strip()
1026
+ _plan = parse_json(response_content)
1027
+ if not _plan:
1028
+ print("[Warning] Failed to parse plan JSON after retry — returning fallback plan.")
1029
+ _plan = {"answer": "none", "reason": "invalid JSON response", "topics": [], "strategy": []}
1030
+ return _plan
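`parse_json` itself is defined earlier in main.py and does not appear in this hunk; the sketch below shows the kind of lenient extraction the retry logic assumes (take the outermost brace pair, which also works when the model wraps its JSON in a code fence). The name `parse_json_loose` is illustrative only.

```python
import json

def parse_json_loose(text: str):
    """Best-effort extraction of a JSON object from an LLM response.
    Returns a dict on success, or None so the caller can retry or fall back."""
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        obj = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None
```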
1031
+
1032
+ def _run_stage1(
1033
+ self,
1034
+ qid: str,
1035
+ question: str,
1036
+ question_date: str,
1037
+ top_k: int,
1038
+ model_info,
1039
+ haystack_sess_ids: List[str],
1040
+ date_lookup: Dict[str, str],
1041
+ semantic_ret_dict: dict,
1042
+ ) -> dict:
1043
+ """
1044
+ Stage 1: retrieve and evaluate using semantic memory only (summaries + facts).
1045
+
1046
+ Returns a dict with:
1047
+ is_answerable : bool
1048
+ answer : str | None (set when is_answerable is True)
1049
+ candidate_ids : list[str] (top-K session IDs from semantic retrieval)
1050
+ attempt_record : list
1051
+ """
1052
+ print(f"\t[Stage 1] Semantic memory retrieval for qid={qid}")
1053
+
1054
+ # --- 1a. Semantic embedding search ---
1055
+ candidate_ids = semantic_embedding_search(
1056
+ qid, haystack_sess_ids, semantic_ret_dict, top_k=top_k
1057
+ )
1058
+
1059
+ # hier_v2: skip plan / keyword / time_filter / is_answerable. Stage 1 is candidate-only.
1060
+ if self.hier_v2:
1061
+ print(f"\t[Stage 1 hier_v2] embedding-only candidates: {len(candidate_ids)}")
1062
+ return {
1063
+ "is_answerable": False,
1064
+ "answer": None,
1065
+ "candidate_ids": candidate_ids[:top_k],
1066
+ "attempt_record": [{
1067
+ "stage": "semantic_v2",
1068
+ "plan": {},
1069
+ "n_candidates": len(candidate_ids),
1070
+ "candidate_ids": candidate_ids[:top_k],
1071
+ }],
1072
+ }
1073
+
1074
+ # --- 1b. Plan for keywords / time filter (reuse existing _plan) ---
1075
+ plan = self._plan(question, question_date, [], model_info)
1076
+ print(json.dumps(plan, indent=4), flush=True)
1077
+
1078
+ # If the planner already has a direct answer, return it
1079
+ if "answer" in plan and plan["answer"].lower() != "none":
1080
+ return {
1081
+ "is_answerable": True,
1082
+ "answer": plan["answer"],
1083
+ "candidate_ids": candidate_ids,
1084
+ "attempt_record": [{"plan": plan, "stage": "semantic"}],
1085
+ }
1086
+
1087
+ # --- 1c. Semantic keyword search ---
1088
+ keyword_ids: List[str] = []
1089
+ for step in plan.get("strategy", []):
1090
+ if step.get("method") == "keyword":
1091
+ kws = step.get("keywords", [])
1092
+ matched = self.semantic_store.keyword_search(kws, haystack_sess_ids)
1093
+ print(f"\t\t** semantic keyword search **: {kws} -> {len(matched)} matches")
1094
+ keyword_ids.extend(sid for sid in matched if sid not in keyword_ids)
1095
+
1096
+ # --- 1d. Time filter on candidate set ---
1097
+ for step in plan.get("strategy", []):
1098
+ if self.no_time_filter:
1099
+ break # skip all time_filter steps
1100
+ if step.get("method") == "time_filter":
1101
+ if "time_range" not in step or len(step["time_range"]) != 2:
1102
+ continue
1103
+ start_str, end_str = step["time_range"]
1104
+ from datetime import datetime
1105
+ try:
1106
+ start_dt = datetime.fromisoformat(start_str)
1107
+ end_dt = datetime.fromisoformat(end_str)
1108
+ candidate_ids = [
1109
+ sid for sid in candidate_ids
1110
+ if sid in date_lookup and
1111
+ start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date()
1112
+ ]
1113
+ keyword_ids = [
1114
+ sid for sid in keyword_ids
1115
+ if sid in date_lookup and
1116
+ start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date()
1117
+ ]
1118
+ print(f"\t\t** semantic time_filter **: {start_str}..{end_str} -> "
1119
+ f"{len(candidate_ids)} embed, {len(keyword_ids)} keyword")
1120
+ except Exception as e:
1121
+ print(f"\t\t[WARN] time_filter parse error: {e}")
1122
+
1123
+ # --- 1e. Merge candidates (keyword union with embedding, preserve rank) ---
1124
+ all_candidate_ids: List[str] = list(candidate_ids)
1125
+ for sid in keyword_ids:
1126
+ if sid not in all_candidate_ids:
1127
+ all_candidate_ids.append(sid)
1128
+
1129
+ # Cap to top_k
1130
+ all_candidate_ids = all_candidate_ids[:top_k]
1131
+
1132
+ if not all_candidate_ids:
1133
+ print("\t[Stage 1] No candidates found in semantic memory.")
1134
+ return {
1135
+ "is_answerable": False,
1136
+ "answer": None,
1137
+ "candidate_ids": [],
1138
+ "attempt_record": [{"plan": plan, "stage": "semantic", "n_candidates": 0}],
1139
+ }
1140
+
1141
+ # --- 1f. Build semantic context string ---
1142
+ semantic_context_str = self.semantic_store.to_prompt(all_candidate_ids, date_lookup)
1143
+ print(f"\t[Stage 1] Built semantic context for {len(all_candidate_ids)} sessions "
1144
+ f"({len(semantic_context_str)} chars)")
1145
+
1146
+ # --- 1g. is_answerable check on semantic context ---
1147
+ accumulated_evidence = {"profile": [], "chat_clues": []}
1148
+ answerable_response = self.is_answerable(
1149
+ question, question_date,
1150
+ retrieved_sess=None,
1151
+ evidence=accumulated_evidence,
1152
+ model_info=model_info,
1153
+ context_str=semantic_context_str,
1154
+ )
1155
+ print(f"\t[Stage 1] is_answerable: {answerable_response}")
1156
+
1157
+ attempt_record = [{
1158
+ "stage": "semantic",
1159
+ "plan": plan,
1160
+ "n_candidates": len(all_candidate_ids),
1161
+ "candidate_ids": all_candidate_ids,
1162
+ "is_answerable": answerable_response.get("is_answerable", False),
1163
+ }]
1164
+
1165
+ if answerable_response.get("is_answerable"):
1166
+ return {
1167
+ "is_answerable": True,
1168
+ "answer": answerable_response.get("answer"),
1169
+ "candidate_ids": all_candidate_ids,
1170
+ "attempt_record": attempt_record,
1171
+ }
1172
+
1173
+ print(f"\t[Stage 1] Not answerable from semantic memory. "
1174
+ f"Info needed: {answerable_response.get('info_needed', [])}")
1175
+ return {
1176
+ "is_answerable": False,
1177
+ "answer": None,
1178
+ "candidate_ids": all_candidate_ids,
1179
+ "attempt_record": attempt_record,
1180
+ }
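The time_filter branch above compares ISO-8601 dates from the plan against haystack session dates parsed by `EpisodicMemoryStore._parse_date`. A compact sketch of that date-window filtering over session IDs; the function and variable names here are illustrative.

```python
from datetime import datetime

def filter_ids_by_date(sess_ids, date_lookup, start_iso, end_iso, parse_date):
    """Keep session IDs whose haystack date falls inside [start_iso, end_iso]."""
    start = datetime.fromisoformat(start_iso).date()
    end = datetime.fromisoformat(end_iso).date()
    kept = []
    for sid in sess_ids:
        if sid not in date_lookup:
            continue
        try:
            day = parse_date(date_lookup[sid]).date()
        except Exception:
            continue  # skip sessions whose date string cannot be parsed
        if start <= day <= end:
            kept.append(sid)
    return kept

# Example with the haystack date format used in this repo.
lookup = {"s1": "2023/05/10 (Wed) 09:00", "s2": "2023/06/01 (Thu) 12:00"}
parse = lambda s: datetime.strptime(
    s.split('(')[0].strip() + s.split(')')[-1].strip(), "%Y/%m/%d%H:%M"
)
assert filter_ids_by_date(["s1", "s2"], lookup, "2023-05-08", "2023-05-22", parse) == ["s1"]
```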
1181
+
1182
+ def run(self, qid:str, question: str, question_date: str, top_k: int, model_info, max_loops=3,
1183
+ semantic_ret_dict: dict = None, haystack_sess_ids: List[str] = None,
1184
+ date_lookup: Dict[str, str] = None, topic_lookup: Dict[str, List[str]] = None):
1185
+ accumulated_evidence = {"profile": [], "chat_clues": []}
1186
+ attempt_record = []
1187
+ loop_num = 0
1188
+
1189
+ # ----------------------------------------------------------------
1190
+ # Stage 1: Semantic memory — only runs when stores are provided
1191
+ # ----------------------------------------------------------------
1192
+ stage1_candidate_ids: List[str] = []
1193
+ if (self.semantic_store is not None
1194
+ and semantic_ret_dict is not None
1195
+ and haystack_sess_ids is not None):
1196
+ stage1_result = self._run_stage1(
1197
+ qid, question, question_date, top_k, model_info,
1198
+ haystack_sess_ids, date_lookup or {}, semantic_ret_dict,
1199
+ )
1200
+ attempt_record.extend(stage1_result["attempt_record"])
1201
+ stage1_candidate_ids = stage1_result["candidate_ids"]
1202
+
1203
+ if stage1_result["is_answerable"] and not self.no_early_answer:
1204
+ print(f"\t[Stage 1] Answered from semantic memory.")
1205
+ # Wrap answer in attempt_record format expected by caller
1206
+ if "answer" in stage1_result and stage1_result["answer"]:
1207
+ attempt_record[0]["plan"] = {
1208
+ **attempt_record[0].get("plan", {}),
1209
+ "answer": stage1_result["answer"],
1210
+ }
1211
+ return ChatHistory(), attempt_record
1212
+ if stage1_result["is_answerable"] and self.no_early_answer:
1213
+ print(f"\t[Stage 1] is_answerable=True but --no_early_answer set; proceeding to Stage-2.")
1214
+
1215
+ # hier_union: widen Stage-2 pool with flat-embedding top-K from the global GTE cache.
1216
+ # This makes hier a strict superset of flat by construction; targets the recall gap
1217
+ # (semantic-over-summary embeddings rank worse than full-session embeddings).
1218
+ if self.hier_union and qid in retrieved_data_dict:
1219
+ flat_ids = flat_embedding_top_k_ids(qid, haystack_sess_ids, self.hier_union_flat_k)
1220
+ before = len(stage1_candidate_ids)
1221
+ for sid in flat_ids:
1222
+ if sid not in stage1_candidate_ids:
1223
+ stage1_candidate_ids.append(sid)
1224
+ print(f"\t[hier_union] semantic_top_k={before} + flat_top_{self.hier_union_flat_k}={len(flat_ids)} -> union={len(stage1_candidate_ids)}")
1225
+ attempt_record.append({
1226
+ "stage": "hier_union",
1227
+ "plan": {},
1228
+ "n_semantic": before,
1229
+ "n_flat": len(flat_ids),
1230
+ "n_union": len(stage1_candidate_ids),
1231
+ })
1232
+
1233
+ # Stage 2: load episodic sessions only for top-K candidates
1234
+ if self.episodic_store is not None and stage1_candidate_ids:
1235
+ print(f"\t[Stage 2] Loading episodic memory for "
1236
+ f"{len(stage1_candidate_ids)} candidate sessions.")
1237
+ raw_sessions = self.episodic_store.get_raw_sessions(
1238
+ stage1_candidate_ids, date_lookup or {}, topic_lookup
1239
+ )
1240
+ self.chat_history = ChatHistory(sessions=raw_sessions)
1241
+ print(f"\t[Stage 2] Loaded {len(self.chat_history)} sessions into episodic pool.")
1242
+
1243
+ # hier_v2: skip the agent loop. Use raw turns of the candidate sessions directly.
1244
+ # Rationale: the agent loop's verification can reject all of a 20-session pool,
1245
+ # leaving empty retrieved. Strong models answer better from raw turns of K=20
1246
+ # semantic-selected sessions than from over-aggressive verification.
1247
+ if self.hier_v2:
1248
+ print(f"\t[hier_v2] bypassing agent loop; returning {len(self.chat_history)} candidate sessions as retrieved")
1249
+ return self.chat_history, attempt_record
1250
+
1251
+ pool = self.chat_history
1252
+ retrieved = ChatHistory()
1253
+
1254
+ while loop_num < max_loops:
1255
+ loop_num += 1
1256
+ if loop_num == 1 and qid in qid2plan:
1257
+ plan = qid2plan[qid]
1258
+ else:
1259
+ plan = self._plan(question, question_date, attempt_record, model_info)
1260
+ qid2plan[qid] = plan
1261
+ print(json.dumps(plan, indent=4), flush=True)
1262
+
1263
+ if "answer" in plan and not ("none" in plan["answer"].lower()):
1264
+ print(f"{qid}\t{question}\t{plan['answer']}", flush=True)
1265
+ return ChatHistory(), [{"plan": plan}]
1266
+
1267
+ # 2) Execute the plan -> retrieve candidates
1268
+ try:
1269
+ candidates, step_logs = self._execute_strategy(pool, plan, question)
1270
+ except Exception as e:
1271
+ print(f"[Error] Failed during _execute_strategy: {e}")
1272
+ candidates, step_logs = ChatHistory(), [f"Execution failed: {e}"]
1273
+
1274
+ if len(candidates) == 0:
1275
+ attempt_record.append({
1276
+ "loop_iteration": loop_num,
1277
+ "plan": plan,
1278
+ "evidence": accumulated_evidence,
1279
+ "n_candidates_sess": len(candidates),
1280
+ "n_verified_sess": 0,
1281
+ "n_pool": len(pool),
1282
+ "step_logs": step_logs,
1283
+ })
1284
+ continue
1285
+ else:
1286
+ remaining = set(pool.get_session_ids()) - set(candidates.get_session_ids())
1287
+ pool = self.chat_history.get_item_by_session_ids(remaining)
1288
+
1289
+ # 3) Verification Reading
1290
+ if qid in qid2rel_sess_ids:
1291
+ verified, evidence_list = self._read_and_verify_with_cache(qid, candidates)
1292
+ else:
1293
+ verified, evidence_list = self._read_and_verify(question, question_date, candidates, n_chunks=self.n_chunks)
1294
+ qid2rel_sess_ids[qid] = verified.get_session_ids()
1295
+
1296
+ if len(verified) == 0:
1297
+ attempt_record.append({
1298
+ "loop_iteration": loop_num,
1299
+ "plan": plan,
1300
+ "evidence": accumulated_evidence,
1301
+ "n_candidates_sess": len(candidates),
1302
+ "candidates_sess_ids": candidates.get_session_ids(),
1303
+ "n_verified_sess": len(verified),
1304
+ "verified_sess_ids": [],
1305
+ "n_pool": len(pool),
1306
+ "step_logs": step_logs,
1307
+ })
1308
+ continue
1309
+ else:
1310
+ retrieved.merge_rel_sess(verified)
1311
+
1312
+ for ev in evidence_list:
1313
+ if ev not in accumulated_evidence['chat_clues']:
1314
+ accumulated_evidence['chat_clues'].append(ev)
1315
+
1316
+ attempt_record.append({
1317
+ "loop_iteration": loop_num,
1318
+ "plan": plan,
1319
+ "evidence": accumulated_evidence,
1320
+ "n_candidates_sess": len(candidates),
1321
+ "candidates_sess_ids": candidates.get_session_ids(),
1322
+ "n_verified_sess": len(verified),
1323
+ "verified_sess_ids": verified.get_session_ids(),
1324
+ "n_retrieved_sess": len(retrieved),
1325
+ "retrieved_sess_ids": retrieved.get_session_ids(),
1326
+ "n_pool": len(pool),
1327
+ "step_logs": step_logs,
1328
+ })
1329
+
1330
+ # 4) Decide if continue or not
1331
+ if len(retrieved) > top_k:
1332
+ retrieved, _ = filter_out_by_embedding(retrieved, qid=qid, top_k=top_k)
1333
+
1334
+ answerable_response = self.is_answerable(question, question_date, retrieved, accumulated_evidence, model_info)
1335
+ if answerable_response["is_answerable"]:
1336
+ plan["answer"] = answerable_response["answer"]
1337
+ return retrieved, attempt_record
1338
+
1339
+ if len(pool) == 0:
1340
+ break
1341
+
1342
+ return retrieved, attempt_record
1343
+
1344
+ def _execute_strategy(self, pool, plan, question):
1345
+ step_logs: List[str] = []
1346
+
1347
+ # Start from all chat items
1348
+ if self.topic_filter and len(plan.get('topics', [])) > 0:
1349
+ pool = pool.get_item_by_topics(plan['topics'])
1350
+
1351
+ retrieved = ChatHistory()
1352
+
1353
+ strategy = plan["strategy"]
1354
+ for step in strategy:
1355
+ method = step.get("method")
1356
+ if method == "keyword":
1357
+ kws = step.get("keywords", [])
1358
+ matched = keyword_search(pool, kws)
1359
+ step_logs.append(f"Method: keyword - matched {len(matched)}/{len(pool)} using {kws}")
1360
+ if len(matched) > 0:
1361
+ retrieved.merge_rel_sess(matched)
1362
+ elif method == "embedding":
1363
+ top_k = 50
1364
+ matched = embedding_search(pool, qid, top_k=top_k)
1365
+ step_logs.append(f"Method: embedding - top_k={top_k}, matched {len(matched)}/{len(pool)}")
1366
+ if len(matched) > 0:
1367
+ retrieved.merge_rel_sess(matched)
1368
+ elif method == "time_filter":
1369
+ if self.no_time_filter:
1370
+ step_logs.append(f"Method: time_filter - skipped (--no_time_filter)")
1371
+ continue
1372
+ if 'time_range' not in step or len(step['time_range']) != 2:
1373
+ continue
1374
+ if len(retrieved) > 0:
1375
+ retrieved = time_filter(retrieved, start_date=step['time_range'][0], end_date=step['time_range'][1])
1376
+ else:
1377
+ retrieved = time_filter(pool, start_date=step['time_range'][0], end_date=step['time_range'][1])
1378
+ step_logs.append(f"Method: time_filter - kept {len(retrieved)}/{len(pool)} in {step['time_range'][0]}..{step['time_range'][1]}")
1379
+ #if len(matched) > 0:
1380
+ # retrieved.merge_rel_sess(matched)
1381
+ else:
1382
+ step_logs.append(f"unknown method: {method}")
1383
+
1384
+ if len(retrieved) > 100:
1385
+ top_k = 100
1386
+ retrieved = embedding_search(retrieved, qid, top_k=top_k)
1387
+ step_logs.append(
1388
+ f"too many sess ({len(pool)}) - embedding top_k={top_k} matched {len(retrieved)}/{len(pool)}"
1389
+ )
1390
+
1391
+ return retrieved, step_logs
1392
+
1393
+ def merge_rel_sess(self, new_sessions: ChatHistory):
1394
+ # Gather all current and new sessions in a dict keyed by session_id
1395
+ all_sessions = {s["session_id"]: s for s in self.rel_sess.sessions}
1396
+
1397
+ # add if new
1398
+ for s in new_sessions.sessions:
1399
+ if s["session_id"] not in all_sessions:
1400
+ all_sessions[s["session_id"]] = s
1401
+
1402
+ # Optional: sort sessions by timestamp for consistent ordering
1403
+ #merged_sessions = list(all_sessions.values())
1404
+ #merged_sessions.sort(key=lambda x: x["timestamp"])
1405
+
1406
+ # Reconstruct raw_data for new ChatHistory
1407
+ merged_raw_data = {
1408
+ "haystack_dates": [s["session_date"] for k, s in all_sessions.items()],
1409
+ "haystack_session_ids": [s["session_id"] for k, s in all_sessions.items()],
1410
+ "haystack_sessions": [s["session"] for k, s in all_sessions.items()],
1411
+ }
1412
+ self.rel_sess = ChatHistory(merged_raw_data)
1413
+
1414
+
1415
+ def merge_evidence(self, new_evidence: list):
1416
+ self.evidence = self.evidence + new_evidence
1417
+ print(f"\t\t Updated evidence: {self.evidence}")
1418
+
1419
+
1420
+ if __name__ == "__main__":
1421
+ parser = argparse.ArgumentParser()
1422
+ parser.add_argument('--in_file', type=str, required=True)
1423
+ parser.add_argument('--out_file', type=str, required=True)
1424
+ parser.add_argument('--model_name', type=str, required=True)
1425
+ parser.add_argument('--top_k', type=int, required=True)
1426
+ parser.add_argument('--debug', action='store_true', default=False)
1427
+ parser.add_argument('--vllm', action='store_true', default=False)
1428
+ parser.add_argument('--tritonai', action='store_true', default=False,
1429
+ help='Use OpenAI-compatible LiteLLM proxy for non-reading LLM calls (set TRITONAI_API_KEY)')
1430
+ parser.add_argument('--nvidia', action='store_true', default=False,
1431
+ help='Use NVIDIA inference API (set NV_API_KEY)')
1432
+ parser.add_argument('--vllm_reading', action='store_true', default=False,
1433
+ help='Use vLLM only for verification reading; all other LLM calls use the proprietary API')
1434
+ parser.add_argument('--n_chunks', type=int, default=10,
1435
+ help='Number of sessions per verification-reading LLM call (default 10)')
1436
+ parser.add_argument('--max_loops', type=int, default=3,
1437
+ help='Maximum retrieval/planning loops for agent mode (default 3)')
1438
+ parser.add_argument('--mode', type=str, default="agent", choices=['agent', 'embed', 'keyword'])
1439
+ parser.add_argument('--topic_filter', type=bool, default=True)
1440
+ parser.add_argument('--user_profile', action=argparse.BooleanOptionalAction, default=True,
1441
+ help='Include user profile in prompts (default: True, use --no-user_profile to disable)')
1442
+ parser.add_argument('--no_semantic', action='store_true', default=False,
1443
+ help='Skip semantic Stage 1; run episodic Stage 2 only on all haystack sessions')
1444
+ parser.add_argument('--no_time_filter', action='store_true', default=False,
1445
+ help='Disable time_filter steps in strategy execution (can reuse plan cache)')
1446
+ # Two-stage memory arguments
1447
+ parser.add_argument('--semantic_ret_cache', type=str, default=None,
1448
+ help='Path to semantic-gte retrieval log (JSONL) for Stage 1')
1449
+ parser.add_argument('--summary_file', type=str, default=None,
1450
+ help='Path to all_session_summary.json for SemanticMemoryStore')
1451
+ parser.add_argument('--facts_file', type=str, default=None,
1452
+ help='Path to all_session_user_facts.json for SemanticMemoryStore')
1453
+ parser.add_argument('--all_sessions_file', type=str, default=None,
1454
+ help='Path to all_sessions.json for lazy episodic loading')
1455
+ parser.add_argument('--no_save_cache', action='store_true', default=False,
1456
+ help='Disable saving plan/reading caches to disk after the run')
1457
+ parser.add_argument('--hier_v2', action='store_true', default=False,
1458
+ help='Stage 1 produces candidates only: skip early-answer return, semantic keyword expansion, time_filter, and is_answerable shortcut')
1459
+ parser.add_argument('--hier_union', action='store_true', default=False,
1460
+ help='hier mode: union Stage-1 semantic candidates with flat-embedding top-K and run agent loop on the merged pool')
1461
+ parser.add_argument('--hier_union_flat_k', type=int, default=20,
1462
+ help='How many flat-embedding top-K IDs to union into the Stage-2 pool (default 20)')
1463
+ parser.add_argument('--no_early_answer', action='store_true', default=False,
1464
+ help='Disable Stage-1 is_answerable early-return shortcut; always proceed to Stage-2 agent loop')
1465
+ parser.add_argument('--answer_prompt_v2', action='store_true', default=False,
1466
+ help='Use the v2 answer prompt with explicit guidance for aggregation, temporal reasoning, knowledge updates, and absence cases.')
1467
+ args = parser.parse_args()
1468
+
1469
+ # Rebind reading cache to include n_chunks so different chunk sizes get separate caches
1470
+ veri_reading_log_file = os.environ['reading_cache'] + f'_nchunks{args.n_chunks}'
1471
+ qid2rel_sess_ids = {}
1472
+ if os.path.exists(veri_reading_log_file):
1473
+ qid2rel_sess_ids = json.load(open(veri_reading_log_file))
1474
+ print(f'Reading cache: {veri_reading_log_file} ({len(qid2rel_sess_ids)} cached entries)')
1475
+
1476
+ in_data = json.load(open(args.in_file))
1477
+ top_k = args.top_k
1478
+ out_file = args.out_file
1479
+
1480
+ model_info = model_zoo[args.model_name]
1481
+ deployment_name, api_version = model_info
1482
+
1483
+ existings = set()
1484
+ retrieval_metric_list = []
1485
+ if os.path.exists(out_file):
1486
+ for line in open(out_file):
1487
+ obj = json.loads(line)
1488
+ existings.add(obj['question_id'])
1489
+ if 'retrieval_metric' in obj:
1490
+ retrieval_metric_list.append(obj['retrieval_metric'])
1491
+
1492
+ out_f = open(out_file, 'a')
1493
+
1494
+ ############# read meta files #####################
1495
+ qid2profiles = {}
1496
+ with open("metadata/generated_user_profile.json") as f:
1497
+ qid2profiles = json.load(f)
1498
+ sess2topic = {}
1499
+ with open("metadata/sessions_with_topic.json") as f:
1500
+ sess2topic = json.load(f)
1501
+
1502
+ # ----------------------------------------------------------------
1503
+ # Two-stage memory stores (optional; activated by CLI args)
1504
+ # ----------------------------------------------------------------
1505
+ semantic_store = None
1506
+ episodic_store = None
1507
+ semantic_ret_dict = None
1508
+
1509
+ if args.summary_file and args.facts_file:
1510
+ semantic_store = SemanticMemoryStore(args.summary_file, args.facts_file)
1511
+
1512
+ if args.all_sessions_file:
1513
+ episodic_store = EpisodicMemoryStore(args.all_sessions_file)
1514
+
1515
+ if args.semantic_ret_cache:
1516
+ print(f"Loading semantic retrieval cache from {args.semantic_ret_cache} ...")
1517
+ sem_ret_data = [json.loads(line) for line in open(args.semantic_ret_cache)]
1518
+ semantic_ret_dict = {x['question_id']: x for x in sem_ret_data}
1519
+ print(f" Loaded {len(semantic_ret_dict)} entries.")
1520
+
1521
+ retrieval_metric_list = []
1522
+ for di, entry in enumerate(in_data):
1523
+ item_start_time = time.time()
1524
+ qid, question, q_date = entry['question_id'], entry['question'], entry['question_date']
1525
+ q_date = entry['question_date']
1526
+
1527
+ if qid in existings:
1528
+ continue
1529
+
1530
+ haystack_sess_ids = entry['haystack_session_ids']
1531
+ haystack_topics = [sess2topic.get(sid, {}).get('category', []) for sid in haystack_sess_ids]
1532
+ date_lookup = dict(zip(haystack_sess_ids, entry['haystack_dates']))
1533
+ topic_lookup = dict(zip(haystack_sess_ids, haystack_topics))
1534
+
1535
+ # Build ChatHistory: lazily from episodic store (two-stage) or from raw data (legacy)
1536
+ if episodic_store is not None:
1537
+ # Two-stage mode: start with full haystack loaded from episodic store
1538
+ # (Stage 1 will narrow this down before Stage 2 runs)
1539
+ raw_sessions = episodic_store.get_raw_sessions(
1540
+ haystack_sess_ids, date_lookup, topic_lookup
1541
+ )
1542
+ chat_history = ChatHistory(sessions=raw_sessions)
1543
+ else:
1544
+ chat_history = ChatHistory({
1545
+ "haystack_dates": entry['haystack_dates'],
1546
+ "haystack_session_ids": entry['haystack_session_ids'],
1547
+ "haystack_sessions": entry['haystack_sessions'],
1548
+ "haystack_topics": haystack_topics,
1549
+ })
1550
+
1551
+ topic_set = set()
1552
+ for ht in haystack_topics:
1553
+ topic_set.update(ht)
1554
+
1555
+ if args.user_profile:
1556
+ # user profile
1557
+ temp_qid = qid
1558
+ if '_q_' in qid:
1559
+ temp_qid = qid.split("_q_")[0]
1560
+ user_profile = qid2profiles[temp_qid]
1561
+ agent = RetrievalAgent(
1562
+ chat_history,
1563
+ list(topic_set),
1564
+ user_profile=user_profile,
1565
+ debug=args.debug,
1566
+ vllm=args.vllm,
1567
+ vllm_reading=args.vllm_reading,
1568
+ tritonai=args.tritonai,
1569
+ nvidia=args.nvidia,
1570
+ n_chunks=args.n_chunks,
1571
+ topic_filter=args.topic_filter,
1572
+ no_time_filter=args.no_time_filter,
1573
+ semantic_store=None if args.no_semantic else semantic_store,
1574
+ episodic_store=episodic_store,
1575
+ hier_v2=args.hier_v2,
1576
+ hier_union=args.hier_union,
1577
+ hier_union_flat_k=args.hier_union_flat_k,
1578
+ no_early_answer=args.no_early_answer,
1579
+ )
1580
+ else:
1581
+ agent = RetrievalAgent(
1582
+ chat_history,
1583
+ list(topic_set),
1584
+ debug=args.debug,
1585
+ vllm=args.vllm,
1586
+ vllm_reading=args.vllm_reading,
1587
+ tritonai=args.tritonai,
1588
+ nvidia=args.nvidia,
1589
+ n_chunks=args.n_chunks,
1590
+ topic_filter=args.topic_filter,
1591
+ no_time_filter=args.no_time_filter,
1592
+ semantic_store=None if args.no_semantic else semantic_store,
1593
+ episodic_store=episodic_store,
1594
+ hier_v2=args.hier_v2,
1595
+ hier_union=args.hier_union,
1596
+ hier_union_flat_k=args.hier_union_flat_k,
1597
+ no_early_answer=args.no_early_answer,
1598
+ )
1599
+
1600
+ try:
1601
+ if args.mode == 'embed':
1602
+ final_sess = embedding_search(chat_history, qid, top_k=top_k)
1603
+ attempt_record = [{"plan": {"answer": "none", "reason": "embedding retrieval only"}}]
1604
+ elif args.mode == 'keyword':
1605
+ keywords = generate_keywords(question, deployment_name, api_version,
1606
+ debug=args.debug, vllm=args.vllm,
1607
+ tritonai=args.tritonai, nvidia=args.nvidia)
1608
+ final_sess = keyword_search(chat_history, keywords=keywords)
1609
+ attempt_record = [{"plan": {"answer": "none", "reason": "keyword retrieval only"}}]
1610
+ else: # agent
1611
+ final_sess, attempt_record = agent.run(
1612
+ qid, question, q_date, top_k, model_info,
1613
+ max_loops=args.max_loops,
1614
+ semantic_ret_dict=semantic_ret_dict,
1615
+ haystack_sess_ids=haystack_sess_ids,
1616
+ date_lookup=date_lookup,
1617
+ topic_lookup=topic_lookup,
1618
+ )
1619
+
1620
+ if len(attempt_record) == 1 and "answer" in attempt_record[0]["plan"] and not ("none" in attempt_record[0]["plan"]["answer"].lower()):
1621
+ answer = attempt_record[0]["plan"]["answer"]
1622
+ token_budget = agent.get_token_budget() if args.mode == 'agent' else {}
1623
+ wall_time_sec = time.time() - item_start_time
1624
+
1625
+ print(json.dumps({"q_idx": di, 'question_id': qid, 'question': entry['question'],
1626
+ 'answer': answer, 'n_retrieved': len(final_sess),
1627
+ 'wall_time_sec': round(wall_time_sec, 3)}, indent=4), flush=True)
1628
+ print(json.dumps({"q_idx": di, 'question_id': qid,
1629
+ 'hypothesis': answer,
1630
+ "attempt_record": attempt_record,
1631
+ "token_budget": token_budget,
1632
+ "wall_time_sec": wall_time_sec}), file=out_f, flush=True)
1633
+ else:
1634
+ if len(final_sess) > top_k and retrieved_log_file is not None:
1635
+ final_top_k_sess, _ = filter_out_by_embedding(final_sess, qid=qid, top_k=top_k)
1636
+ retrieved_str = final_top_k_sess.to_prompt(granularity="session", _format="json")
1637
+ else:
1638
+ retrieved_str = final_sess.to_prompt(granularity="session", _format="json")
1639
+
1640
+ if args.answer_prompt_v2:
1641
+ answer_prompt_template = (
1642
+ "You are answering a question using a list of chat-session transcripts between the user and an assistant.\n"
1643
+ "\n"
1644
+ "How to answer:\n"
1645
+ "1. Scan ALL retrieved sessions in chronological order. The SESSION DATE on each transcript is when that conversation occurred. The Current Date below is when the question was asked, not when events happened.\n"
1646
+ "2. Identify every session containing a candidate fact. If sessions conflict, prefer the most RECENT session that addresses the same fact (knowledge update).\n"
1647
+ "3. For aggregation questions ('how many', 'list all', 'between X and Y'), enumerate matches across ALL relevant sessions; do not stop at the first.\n"
1648
+ "4. For temporal queries ('last Friday', 'two weeks ago'), resolve the relative date against the SESSION DATE of the session that uses that phrase, not the Current Date.\n"
1649
+ "5. If the retrieved sessions do NOT contain the answer, reply exactly 'Insufficient information in retrieved sessions.' Do not fabricate.\n"
1650
+ "6. Be terse: state the direct answer first, then one short sentence citing the session date(s) you relied on.\n"
1651
+ "\n"
1652
+ "Chat history sessions:\n"
1653
+ "\n"
1654
+ "{}\n"
1655
+ "\n"
1656
+ "Current Date: {}\n"
1657
+ "Question: {}\n"
1658
+ "Answer:"
1659
+ )
1660
+ else:
1661
+ answer_prompt_template = "I will give you several chat history sessions between you and a user. Please answer the question given the information.\n\n\nChat history sessions:\n\n{}\n\nCurrent Date: {}\nQuestion: {}\nAnswer:"
1662
+ answer_prompt = answer_prompt_template.format(retrieved_str, entry['question_date'], entry['question'])
1663
+
1664
+ completion = llm_call(
1665
+ deployment_name,
1666
+ api_version,
1667
+ answer_prompt,
1668
+ debug=args.debug,
1669
+ vllm=args.vllm,
1670
+ tritonai=args.tritonai,
1671
+ nvidia=args.nvidia,
1672
+ )
1673
+ answer = (completion.choices[0].message.content or "").strip()
1674
+
1675
+ if args.mode == 'agent':
1676
+ agent._track_usage('final_answer', completion)
1677
+ token_budget = agent.get_token_budget() if args.mode == 'agent' else {}
1678
+
1679
+ retrieval_metric = {}
1680
+ if len(final_sess) > 0 and retrieved_log_file is not None:
1681
+ sess_sorted = embedding_search(final_sess, qid, top_k=20)
1682
+ sess_id_sorted = sess_sorted.get_session_ids()
1683
+
1684
+ for topk in [5, 10, 20, 30]:
1685
+ recall_any, recall_all = evaluate_retrieval(sess_id_sorted[:topk], entry['answer_session_ids'])
1686
+ retrieval_metric.update({
1687
+ 'recall_any@{}'.format(topk): recall_any,
1688
+ 'recall_all@{}'.format(topk): recall_all
1689
+ })
1690
+ retrieval_metric_list.append(retrieval_metric)
1691
+ print_average_metrics(retrieval_metric_list)
1692
+
1693
+ print(json.dumps({"q_idx": di, 'n_prompt_tok': completion.usage.prompt_tokens,
1694
+ 'n_completion_tok': completion.usage.completion_tokens,
1695
+ 'hypothesis': answer,
1696
+ 'wall_time_sec': round(time.time() - item_start_time, 3)}), flush=True)
1697
+ print(json.dumps({"q_idx": di, 'question_id': qid,
1698
+ 'hypothesis': answer,
1699
+ 'n_prompt_tok': completion.usage.prompt_tokens,
1700
+ 'n_completion_tok': completion.usage.completion_tokens,
1701
+ "attempt_record": attempt_record,
1702
+ "retrieved_sess_ids": final_sess.get_session_ids(),
1703
+ "retrieval_metric": retrieval_metric,
1704
+ "token_budget": token_budget,
1705
+ "wall_time_sec": time.time() - item_start_time}), file=out_f, flush=True)
1706
+ except Exception as e:
1707
+ print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
1708
+ continue
1709
+
1710
+ ############# save cache ##########################
1711
+
1712
+ if not args.no_save_cache:
1713
+ with open(plan_cache_file, "w") as fw:
1714
+ json.dump(qid2plan, fw, indent=2)
1715
+
1716
+ with open(veri_reading_log_file, "w") as fw:
1717
+ json.dump(qid2rel_sess_ids, fw, indent=2)
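Each processed question is appended to `out_file` as one JSON line carrying `hypothesis`, `attempt_record`, `retrieval_metric`, `token_budget`, and `wall_time_sec`. A small post-hoc reader that averages the recall metrics from such a log could look like this sketch (the path in the commented call is hypothetical):

```python
import json
from collections import defaultdict

def summarize_run(out_file: str):
    """Average the recall_any@K / recall_all@K metrics over a per-question JSONL log."""
    sums, counts, n = defaultdict(float), defaultdict(int), 0
    with open(out_file) as f:
        for line in f:
            rec = json.loads(line)
            n += 1
            for k, v in rec.get("retrieval_metric", {}).items():
                sums[k] += v
                counts[k] += 1
    print(f"{n} questions")
    for k in sorted(sums):
        print(f"{k}: {sums[k] / counts[k]:.3f}")

# summarize_run("outputs/agent_top20.jsonl")  # hypothetical output path
```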
memory/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from memory.episodic_store import EpisodicMemoryStore
2
+ from memory.semantic_store import SemanticMemoryStore
memory/episodic_store.py ADDED
@@ -0,0 +1,62 @@
1
+ """
2
+ episodic_store.py
3
+
4
+ Lazy loader for all_sessions.json. Keeps the full raw-turn data in memory
5
+ and returns sessions on demand by session ID, avoiding loading all sessions
6
+ into ChatHistory upfront.
7
+ """
8
+
9
+ import json
10
+ from datetime import datetime
11
+ from typing import Dict, List, Optional
12
+
13
+
14
+ class EpisodicMemoryStore:
15
+ def __init__(self, all_sessions_path: str):
16
+ print(f"[EpisodicMemoryStore] Loading {all_sessions_path} ...")
17
+ with open(all_sessions_path) as f:
18
+ self._data: Dict[str, List] = json.load(f)
19
+ print(f"[EpisodicMemoryStore] Loaded {len(self._data)} sessions.")
20
+
21
+ @staticmethod
22
+ def _parse_date(date_str: str) -> datetime:
23
+ """Convert '2023/04/10 (Mon) 17:50' to datetime."""
24
+ date_part = date_str.split('(')[0].strip()
25
+ time_part = date_str.split(')')[-1].strip()
26
+ return datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M")
27
+
28
+ def get_raw_sessions(
29
+ self,
30
+ sess_ids: List[str],
31
+ date_lookup: Dict[str, str],
32
+ topic_lookup: Optional[Dict[str, List[str]]] = None,
33
+ ) -> List[dict]:
34
+ """
35
+ Return a list of session dicts compatible with ChatHistory(sessions=...).
36
+
37
+ Args:
38
+ sess_ids: Ordered list of session IDs to load.
39
+ date_lookup: Mapping sess_id -> date string (e.g. '2023/04/10 (Mon) 17:50').
40
+ topic_lookup: Optional mapping sess_id -> list of topic strings.
41
+
42
+ Returns:
43
+ List of session dicts, each with keys:
44
+ session_id, session_date, session, topic, timestamp
45
+ """
46
+ sessions = []
47
+ for sid in sess_ids:
48
+ if sid not in self._data:
49
+ continue
50
+ date_str = date_lookup.get(sid, "")
51
+ try:
52
+ ts = self._parse_date(date_str) if date_str else datetime.min
53
+ except Exception:
54
+ ts = datetime.min
55
+ sessions.append({
56
+ "session_id": sid,
57
+ "session_date": date_str,
58
+ "session": self._data[sid],
59
+ "topic": topic_lookup.get(sid, []) if topic_lookup else [],
60
+ "timestamp": ts,
61
+ })
62
+ return sessions
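A quick sanity check of the date format handled by `_parse_date` above, using the example string from the docstring:

```python
from datetime import datetime

# '2023/04/10 (Mon) 17:50' -> drop the '(Mon)' part, stitch date + time, parse.
date_str = "2023/04/10 (Mon) 17:50"
date_part = date_str.split('(')[0].strip()   # '2023/04/10'
time_part = date_str.split(')')[-1].strip()  # '17:50'
ts = datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M")
assert ts == datetime(2023, 4, 10, 17, 50)
```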
memory/semantic_store.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ semantic_store.py
3
+
4
+ Wrapper around all_session_summary.json and all_session_user_facts.json.
5
+ Provides:
6
+ - keyword_search(): find sessions whose semantic text contains given keywords
7
+ - to_prompt(): format semantic context for LLM consumption
8
+ - get_text(): return raw semantic text for a session (for embedding/search)
9
+ """
10
+
11
+ import json
12
+ from typing import Dict, List, Optional
13
+
14
+
15
+ class SemanticMemoryStore:
16
+ def __init__(self, summary_path: str, facts_path: str):
17
+ print(f"[SemanticMemoryStore] Loading {summary_path} ...")
18
+ with open(summary_path) as f:
19
+ self._summaries: Dict[str, dict] = json.load(f)
20
+
21
+ print(f"[SemanticMemoryStore] Loading {facts_path} ...")
22
+ with open(facts_path) as f:
23
+ self._facts: Dict[str, list] = json.load(f)
24
+
25
+ print(f"[SemanticMemoryStore] Loaded {len(self._summaries)} summaries, "
26
+ f"{len(self._facts)} fact entries.")
27
+
28
+ def get_summary(self, sess_id: str) -> str:
29
+ """Return the session-level summary string, or empty string."""
30
+ entry = self._summaries.get(sess_id, {})
31
+ return entry.get("session_summary", "").strip()
32
+
33
+ def get_facts_text(self, sess_id: str) -> str:
34
+ """Return user facts as a single joined string, or empty string."""
35
+ fact_list = self._facts.get(sess_id, [])
36
+ if not fact_list:
37
+ return ""
38
+ return " ".join(
39
+ f["user-info"] for f in fact_list
40
+ if isinstance(f, dict) and f.get("user-info")
41
+ ).strip()
42
+
43
+ def get_text(self, sess_id: str) -> str:
44
+ """Return summary + facts combined (for keyword search or display)."""
45
+ parts = [self.get_summary(sess_id), self.get_facts_text(sess_id)]
46
+ return " ".join(p for p in parts if p)
47
+
48
+ def keyword_search(self, keywords: List[str], haystack_sess_ids: List[str]) -> List[str]:
49
+ """
50
+ Search semantic text (summary + facts) of the given sessions for any keyword.
51
+
52
+ Returns:
53
+ List of matching session IDs (preserving haystack order).
54
+ """
55
+ matched = []
56
+ kws_lower = [kw.lower() for kw in keywords if kw]
57
+ for sid in haystack_sess_ids:
58
+ text = self.get_text(sid).lower()
59
+ if any(kw in text for kw in kws_lower):
60
+ matched.append(sid)
61
+ return matched
62
+
63
+ def to_prompt(self, sess_ids: List[str], date_lookup: Optional[Dict[str, str]] = None) -> str:
64
+ """
65
+ Format semantic context for these sessions as a prompt string.
66
+
67
+ Each session block:
68
+ Session Date: <date>
69
+ Summary: <session_summary>
70
+ User Facts: <fact1>; <fact2>; ...
71
+ """
72
+ lines = []
73
+ for sid in sess_ids:
74
+ date_str = date_lookup.get(sid, "") if date_lookup else ""
75
+ summary = self.get_summary(sid)
76
+ facts_text = self.get_facts_text(sid)
77
+
78
+ block = f"Session ID: {sid}"
79
+ if date_str:
80
+ block += f"\nSession Date: {date_str}"
81
+ if summary:
82
+ block += f"\nSummary: {summary}"
83
+ if facts_text:
84
+ block += f"\nUser Facts: {facts_text}"
85
+ lines.append(block)
86
+
87
+ return "\n\n".join(lines)
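A throwaway usage sketch for `SemanticMemoryStore`, assuming the repo root is on PYTHONPATH; the fixture contents are invented purely for illustration but follow the schemas the store reads (`session_summary` per session, `user-info` per fact):

```python
import json, os, tempfile
from memory.semantic_store import SemanticMemoryStore

summaries = {"s1": {"session_summary": "User planned a Grand Canyon trip with Sarah."}}
facts = {"s1": [{"user-info": "User's sister is named Sarah."}]}

tmp = tempfile.mkdtemp()
sum_path = os.path.join(tmp, "all_session_summary.json")
facts_path = os.path.join(tmp, "all_session_user_facts.json")
with open(sum_path, "w") as f:
    json.dump(summaries, f)
with open(facts_path, "w") as f:
    json.dump(facts, f)

store = SemanticMemoryStore(sum_path, facts_path)
assert store.keyword_search(["grand canyon"], ["s1"]) == ["s1"]
print(store.to_prompt(["s1"], {"s1": "2023/04/10 (Mon) 17:50"}))
```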
model_zoo.py ADDED
@@ -0,0 +1,31 @@
1
+ model_zoo = {
2
+ # OpenAI / Azure-hosted models (deployment string, api_version)
3
+ 'gpt-5': ("gpt-5_2025-08-07", "2024-12-01-preview"),
4
+ "gpt-4.1-azure": ("gpt-4.1_2025-04-14", "2025-04-01-preview"),
5
+ 'gpt-4o': ('gpt-4o_2024-11-20', '2024-10-21'),
6
+ 'gpt-4o-mini': ("gpt-4o-mini", ""),
7
+ 'gpt-5-openai': ("gpt-5", ""),
8
+ 'gpt-5-mini-openai': ("gpt-5-mini", ""),
9
+
10
+ # vLLM-hosted models (OpenAI-compatible server)
11
+ 'Qwen3-30B-A3B-Instruct-2507': ("Qwen/Qwen3-30B-A3B-Instruct-2507", ""),
12
+ 'Qwen3-VL-30B-A3B-Instruct': ("Qwen3-VL-30B-A3B-Instruct", ""),
13
+
14
+ # Anthropic models via direct Anthropic API (uses ANTHROPIC_API_KEY)
15
+ 'claude-opus-4-6': ("claude-opus-4-6", ""),
16
+ 'claude-sonnet-4-6': ("claude-sonnet-4-6", ""),
17
+
18
+ # Anthropic / DeepSeek via an OpenAI-compatible LiteLLM proxy
19
+ # (uses LITELLM_API_KEY; selected by main.py's --tritonai flag)
20
+ 'claude-opus-4-6-tritonai': ("us.anthropic.claude-opus-4-6-v1", ""),
21
+ 'claude-sonnet-4-6-tritonai': ("us.anthropic.claude-sonnet-4-6-v1", ""),
22
+ 'deepseek-r1-tritonai': ("us.deepseek.r1-v1:0", ""),
23
+
24
+ # Models served via an OpenAI-compatible inference API (uses NV_API_KEY)
25
+ 'gpt-5.1': ("openai/openai/gpt-5.1", ""),
26
+ 'gpt-5.2': ("openai/openai/gpt-5.2", ""),
27
+ 'gpt-5.5': ("openai/openai/gpt-5.5", ""),
28
+ 'gpt-4.1': ("us/azure/openai/gpt-4.1", ""),
29
+ 'Qwen3.5-397B-A17B': ("nvidia/qwen/qwen3-5-397b-a17b", ""),
30
+ 'Kimi-K2.6': ("nvidia/moonshotai/kimi-k2.6", ""),
31
+ }
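Entries are consumed by tuple unpacking, as main.py does; `api_version` only matters for the Azure-hosted deployments and stays empty for the OpenAI-compatible endpoints:

```python
from model_zoo import model_zoo

deployment_name, api_version = model_zoo["gpt-4o"]
print(deployment_name)                      # 'gpt-4o_2024-11-20'
print(api_version or "<no api_version>")    # Azure entries carry a version string
```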
prompts/agentic_retrieval_prompt.txt ADDED
@@ -0,0 +1,226 @@
1
+ You are an intelligent assistant for memory retrieval.
2
+
3
+ Your goal is to **design and refine retrieval strategies** for answering a user’s query by leveraging the user’s memory resources (profile, topic, chat history, and other contextual knowledge).
4
+
5
+ ### Core Principles
6
+
7
+ 1. Decision-first retrieval: Decide whether retrieval is necessary.
8
+
9
+ * If the query can be answered **directly from the user’s profile**, retrieval is unnecessary.
10
+ * Otherwise, retrieval strategies must be proposed.
11
+ * If the user's profile contains some useful information, keep it and continue retrieving.
12
+
13
+ 2. Identify relevant topics:
14
+
15
+ * Given the topic list from user's chat history, identify topics related to the query.
16
+ * The topics will be used to narrow down the search space. Be inclusive but not too general.
17
+
18
+ 3. Multi-method retrieval: Multiple retrieval methods may be combined:
19
+
20
+ * **Keyword-based retrieval** for unique names, places, or identifiers.
21
+ * **Embedding-based semantic search** when the query is vague, abstract, or conversational.
22
+ * **Time-based filtering** ONLY when the query references dates, ranges, or relative temporal expressions (e.g., “last week,” “yesterday”); use the given `query_date` to resolve them into precise ISO 8601 ranges.
23
+
24
+ 4. Loop-aware evidence collection: Retrieval may occur in **multiple iterations (loops)**. At each loop:
25
+
26
+ * Collect **evidence** from:
27
+ * User profile (static attributes like age, location, job, preferences)
28
+ * Topics (higher-level semantic indexing)
29
+ * Raw chat sessions (extract **key clue sentences**)
30
+ * Incorporate **previous retrieval attempts** (if provided) and refine the strategy. If your previous attempt failed to retrieve relevant sessions or evidence, try again without filter methods such as topic and time filters.
31
+
32
+ 5. Consistent JSON output: All outputs must follow the unified schema to enable downstream automation.
33
+
34
+ ---
35
+
36
+ ## JSON Output Schema
37
+
38
+ ```json
39
+ {
40
+ "query": "<original user query>",
41
+ "query_date": "<ISO 8601 date string>",
42
+ "loop_iteration": <integer, starts at 1>,
43
+ "retrieval_decision": "<none | retrieval_required>",
44
+ "answer": "<none | answer if possilbe>",
45
+ "topics": ["<list of relevant topics>"],
46
+ "strategy": [
47
+ {
48
+ "method": "<keyword | time_filter | embedding>",
49
+ "conditions": "<why this method is chosen>",
50
+ "keywords": ["<list of keywords>"],
51
+ "time_range": ["<start_date>", "<end_date>"]
52
+ }
53
+ ],
54
+ "evidence": {
55
+ "profile": ["<relevant profile snippets>"],
56
+ "chat_clues": ["<list of key sentences extracted from chat history>"]
57
+ },
58
+ "previous_attempts": [
59
+ {
60
+ "loop_iteration": <integer>,
61
+ "strategy": [ ... ],
62
+ "evidence": { ... },
63
+ "outcome": "<insufficient | useful | final_answer_ready>"
64
+ }
65
+ ]
66
+ }
67
+ ```
68
+
69
+ ---
70
+
71
+ ## Examples
72
+
73
+ ### Example 1: No retrieval needed (profile sufficient)
74
+
75
+ ```json
76
+ {
77
+ "query": "Where do I live?",
78
+ "query_date": "2025-08-29",
79
+ "loop_iteration": 1,
80
+ "retrieval_decision": "none",
81
+ "answer": "San Diego, California",
82
+ "topics": [],
83
+ "strategy": [],
84
+ "evidence": {
85
+ "profile": ["User lives in San Diego, California."],
86
+ "chat_clues": []
87
+ },
88
+ "previous_attempts": []
89
+ }
90
+ ```
91
+
92
+ ---
93
+
94
+ ### Example 2: Keyword strategy
95
+
96
+ ```json
97
+ {
98
+ "query": "Who did I go to the Grand Canyon with?",
99
+ "query_date": "2025-08-29",
100
+ "loop_iteration": 1,
101
+ "retrieval_decision": "retrieval_required",
102
+ "answer": "none",
103
+ "topics": ["Travel & Transportation", "Family & Relationships", "Personal Development"],
104
+ "strategy": [
105
+ {
106
+ "method": "keyword",
107
+ "conditions": "Query contains specific event and place name.",
108
+ "keywords": ["Grand Canyon"]
109
+ }
110
+ ],
111
+ "evidence": {
112
+ "profile": [],
113
+ "chat_clues": []
114
+ },
115
+ "previous_attempts": []
116
+ }
117
+ ```
118
+
119
+ ---
120
+
121
+ ### Example 3: Time-filter strategy with relative date
122
+
123
+ Query: How many miles have I hiked in the past two weeks?
124
+ Query date: 2023/05/22 (Mon) 23:31
125
+
126
+ ```json
127
+ {
128
+ "query": "How many miles have I hiked in the past two weeks?",
129
+ "query_date": "2023-05-22 (Mon) 23:31",
130
+ "loop_iteration": 1,
131
+ "retrieval_decision": "retrieval_required",
132
+ "answer": "none",
133
+ "topics": ["Sports & Fitness", "Health & Wellness", "Travel & Transportation", "Personal Development"],
134
+ "strategy": [
135
+ {
136
+ "method": "time_filter",
137
+ "conditions": "Temporal phrase 'past two weeks' resolved using query_date.",
138
+ "time_range": ["2023-05-08", "2023-05-22"]
139
+ }
140
+ ],
141
+ "evidence": {
142
+ "profile": [],
143
+ "chat_clues": []
144
+ },
145
+ "previous_attempts": [],
146
+ }
147
+ ```
148
+
149
+ ---
150
+
151
+ ### Example 4: Keyword + Embedding
152
+
153
+ ```json
154
+ {
155
+ "query": "What did my doctor say about my back pain treatment options?",
156
+ "query_date": "2025-08-29",
157
+ "loop_iteration": 1,
158
+ "retrieval_decision": "retrieval_required",
159
+ "answer": "none",
160
+ "topics": ["Health & Wellness", "Work & Career"],
161
+ "strategy": [
162
+ {
163
+ "method": "keyword",
164
+ "conditions": "Query contains specific medical terms that should be matched exactly.",
165
+ "keywords": ["doctor", "back pain", "treatment"]
166
+ },
167
+ {
168
+ "method": "embedding",
169
+ "conditions": "Query involves medical advice and recommendations which may be expressed in various conversational ways in chat history."
170
+ }
171
+ ],
172
+ "evidence": {
173
+ "profile": [],
174
+ "chat_clues": []
175
+ },
176
+ "previous_attempts": []
177
+ }
178
+ ```
179
+
180
+ ### Example 5: Multi-method strategy with loop refinement
181
+
182
+ Loop 1 outcome was insufficient, Loop 2 continues with broader retrieval.
183
+
184
+ ```json
185
+ {
186
+ "query": "What were we discussing about public issues?",
187
+ "query_date": "2025-08-29",
188
+ "loop_iteration": 2,
189
+ "retrieval_decision": "retrieval_required",
190
+ "answer": "none",
191
+ "topics": ["Government & Politics", "Environment & Sustainability", "Legal"],
192
+ "strategy": [
193
+ {
194
+ "method": "embedding",
195
+ "conditions": "To capture abstract conversational references."
196
+ }
197
+ ],
198
+ "evidence": {
199
+ "profile": [],
200
+ "chat_clues": [
201
+ "User debated about climate policy impacts.",
202
+ "Conversation on local housing regulations was tagged as 'public issues'."
203
+ ]
204
+ },
205
+ "previous_attempts": [
206
+ {
207
+ "loop_iteration": 1,
208
+ "strategy": [
209
+ {"method": "keyword", "conditions": "Query asks about 'public issue'", "keywords": ["public issue", "public issues", "issue"]}
210
+ ],
211
+ "evidence": {
212
+ "profile": [],
213
+ "chat_clues": []
214
+ },
215
+ "outcome": "insufficient"
216
+ }
217
+ ]
218
+ }
219
+ ```
220
+
221
+ Use clear reasoning for how each method applies (or not) to the specific query. Be concise and precise.
222
+
223
+ **Do not include strategies in the strategy list unless they are needed for this query. Omit unused methods.**
224
+ **Always use the provided `query_date` to resolve any relative dates in the query.**
225
+ **Strictly follow the given JSON output format. Return only **one** JSON output.**
226
+ **Use time-based filtering ONLY when the query references dates, ranges, or relative temporal expressions (e.g., "last week," "yesterday," "in the last 3 months").**
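A lightweight structural check one might run on the planner's JSON before executing the strategy; field names follow the schema documented above, and the validator itself is only an illustrative sketch, not part of the released pipeline.

```python
ALLOWED_METHODS = {"keyword", "time_filter", "embedding"}

def validate_plan(plan: dict) -> list:
    """Return a list of problems found in a plan dict (empty means it looks well-formed)."""
    problems = []
    for key in ("query", "retrieval_decision", "answer", "topics", "strategy"):
        if key not in plan:
            problems.append(f"missing key: {key}")
    for step in plan.get("strategy", []):
        method = step.get("method")
        if method not in ALLOWED_METHODS:
            problems.append(f"unknown method: {method}")
        if method == "keyword" and not step.get("keywords"):
            problems.append("keyword step without keywords")
        if method == "time_filter" and len(step.get("time_range", [])) != 2:
            problems.append("time_filter step without a [start, end] range")
    return problems
```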
prompts/agentic_retrieval_prompt_wo_profile.txt ADDED
@@ -0,0 +1,203 @@
1
+ You are an intelligent assistant for memory retrieval.
2
+
3
+ Your goal is to **design and refine retrieval strategies** for answering a user’s query by leveraging the user’s memory resources (topic, chat history, and other contextual knowledge).
4
+
5
+ ### Core Principles
6
+
7
+ 1. Decision-first retrieval: Decide whether retrieval is necessary.
8
+
9
+ * If the query can be answered **directly**, retrieval is unnecessary.
10
+ * Otherwise, retrieval strategies must be proposed.
11
+
12
+ 2. Identify relevant topics:
13
+
14
+ * Given the topic list from user's chat history, identify topics related to the query.
15
+ * The topics will be used to narrow down the search space. Be inclusive but not too general.
16
+
17
+ 3. Multi-method retrieval: Multiple retrieval methods may be combined:
18
+
19
+ * **Keyword-based retrieval** for unique names, places, or identifiers.
20
+ * **Embedding-based semantic search** when the query is vague, abstract, or conversational.
21
+ * **Time-based filtering** ONLY when the query references dates, ranges, or relative temporal expressions (e.g., “last week,” “yesterday”); use the given `query_date` to resolve them into precise ISO 8601 ranges.
22
+
23
+ 4. Loop-aware evidence collection: Retrieval may occur in **multiple iterations (loops)**. At each loop:
24
+
25
+ * Collect **evidence** from:
26
+ * Topics (higher-level semantic indexing)
27
+ * Raw chat sessions (extract **key clue sentences**)
28
+ * Incorporate **previous retrieval attempts** (if provided) and refine the strategy. If your previous attempt failed to retrieve relevant sessions or evidence, try again without filter methods such as topic and time filters.
29
+
30
+ 5. Consistent JSON output: All outputs must follow the unified schema to enable downstream automation.
31
+
32
+ ---
33
+
34
+ ## JSON Output Schema
35
+
36
+ ```json
37
+ {
38
+ "query": "<original user query>",
39
+ "query_date": "<ISO 8601 date string>",
40
+ "loop_iteration": <integer, starts at 1>,
41
+ "retrieval_decision": "<none | retrieval_required>",
42
+ "answer": "<none | answer if possilbe>",
43
+ "topics": ["<list of relevant topics>"],
44
+ "strategy": [
45
+ {
46
+ "method": "<keyword | time_filter | embedding>",
47
+ "conditions": "<why this method is chosen>",
48
+ "keywords": ["<list of keywords>"],
49
+ "time_range": ["<start_date>", "<end_date>"]
50
+ }
51
+ ],
52
+ "evidence": {
53
+ "profile": [],
54
+ "chat_clues": ["<list of key sentences extracted from chat history>"]
55
+ },
56
+ "previous_attempts": [
57
+ {
58
+ "loop_iteration": <integer>,
59
+ "strategy": [ ... ],
60
+ "evidence": { ... },
61
+ "outcome": "<insufficient | useful | final_answer_ready>"
62
+ }
63
+ ]
64
+ }
65
+ ```
66
+
67
+ ---
68
+
69
+ ## Examples
70
+
71
+ ### Example 1: Keyword strategy
72
+
73
+ ```json
74
+ {
75
+ "query": "Who did I go to the Grand Canyon with?",
76
+ "query_date": "2025-08-29",
77
+ "loop_iteration": 1,
78
+ "retrieval_decision": "retrieval_required",
79
+ "answer": "none",
80
+ "topics": ["Travel & Transportation", "Family & Relationships", "Personal Development"],
81
+ "strategy": [
82
+ {
83
+ "method": "keyword",
84
+ "conditions": "Query contains specific event and place name.",
85
+ "keywords": ["Grand Canyon"]
86
+ }
87
+ ],
88
+ "evidence": {
89
+ "profile": [],
90
+ "chat_clues": []
91
+ },
92
+ "previous_attempts": []
93
+ }
94
+ ```
95
+
96
+ ---
97
+
98
+ ### Example 2: Time-filter strategy with relative date
99
+
100
+ Query: How many miles have I hiked in the past two weeks?
101
+ Query date: 2023/05/22 (Mon) 23:31
102
+
103
+ ```json
104
+ {
105
+ "query": "How many miles have I hiked in the past two weeks?",
106
+ "query_date": "2023-05-22 (Mon) 23:31",
107
+ "loop_iteration": 1,
108
+ "retrieval_decision": "retrieval_required",
109
+ "answer": "none",
110
+ "topics": ["Sports & Fitness", "Health & Wellness", "Travel & Transportation", "Personal Development"],
111
+ "strategy": [
112
+ {
113
+ "method": "time_filter",
114
+ "conditions": "Temporal phrase 'past two weeks' resolved using query_date.",
115
+ "time_range": ["2023-05-08", "2023-05-22"]
116
+ }
117
+ ],
118
+ "evidence": {
119
+ "profile": [],
120
+ "chat_clues": []
121
+ },
122
+ "previous_attempts": [],
123
+ }
124
+ ```
125
+
126
+ ---
127
+
128
+ ### Example 3: Keyword + Embedding
129
+
130
+ ```json
131
+ {
132
+ "query": "What did my doctor say about my back pain treatment options?",
133
+ "query_date": "2025-08-29",
134
+ "loop_iteration": 1,
135
+ "retrieval_decision": "retrieval_required",
136
+ "answer": "none",
137
+ "topics": ["Health & Wellness", "Work & Career"],
138
+ "strategy": [
139
+ {
140
+ "method": "keyword",
141
+ "conditions": "Query contains specific medical terms that should be matched exactly.",
142
+ "keywords": ["doctor", "back pain", "treatment"]
143
+ },
144
+ {
145
+ "method": "embedding",
146
+ "conditions": "Query involves medical advice and recommendations which may be expressed in various conversational ways in chat history."
147
+ }
148
+ ],
149
+ "evidence": {
150
+ "profile": [],
151
+ "chat_clues": []
152
+ },
153
+ "previous_attempts": []
154
+ }
155
+ ```
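A rough sketch of how the two strategies in this example could be executed together; `embed` and `sessions` are hypothetical stand-ins for whatever encoder and session store the surrounding pipeline uses, and the union rule is an assumption.

```python
# Illustrative only: union of keyword hits and top-k embedding hits.
import numpy as np

def hybrid_retrieve(query, keywords, sessions, embed, top_k=5):
    # Keyword pass: case-insensitive substring match.
    keyword_hits = {
        i for i, text in enumerate(sessions)
        if any(kw.lower() in text.lower() for kw in keywords)
    }
    # Embedding pass: cosine similarity between the query and each session.
    q = embed(query)
    sims = []
    for text in sessions:
        v = embed(text)
        sims.append(float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v) + 1e-8)))
    embedding_hits = {int(i) for i in np.argsort(sims)[::-1][:top_k]}
    return sorted(keyword_hits | embedding_hits)
```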
156
+
157
+ ### Example 4: Multi-method strategy with loop refinement
158
+
159
+ Loop 1's outcome was insufficient, so Loop 2 continues with a broader retrieval strategy.
160
+
161
+ ```json
162
+ {
163
+ "query": "What were we discussing about public issues?",
164
+ "query_date": "2025-08-29",
165
+ "loop_iteration": 2,
166
+ "retrieval_decision": "retrieval_required",
167
+ "answer": "none",
168
+ "topics": ["Government & Politics", "Environment & Sustainability", "Legal"],
169
+ "strategy": [
170
+ {
171
+ "method": "embedding",
172
+ "conditions": "To capture abstract conversational references."
173
+ }
174
+ ],
175
+ "evidence": {
176
+ "profile": [],
177
+ "chat_clues": [
178
+ "User debated about climate policy impacts.",
179
+ "Conversation on local housing regulations was tagged as 'public issues'."
180
+ ]
181
+ },
182
+ "previous_attempts": [
183
+ {
184
+ "loop_iteration": 1,
185
+ "strategy": [
186
+ {"method": "keyword", "conditions": "Query asks about 'public issue'", "keywords": ["public issue", "public issues", "issue"]}
187
+ ],
188
+ "evidence": {
189
+ "profile": [],
190
+ "chat_clues": []
191
+ },
192
+ "outcome": "insufficient"
193
+ }
194
+ ]
195
+ }
196
+ ```
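The loop refinement shown in this example can be pictured as a small driver like the sketch below; `plan`, `execute`, and `judge` are hypothetical callables standing in for the planner LLM call, the retrieval step, and the sufficiency check, not functions from the released code.

```python
# Illustrative only: feed earlier attempts back into the planner on each iteration.
def retrieval_loop(query, query_date, plan, execute, judge, max_loops=3):
    previous_attempts, sessions = [], []
    for i in range(1, max_loops + 1):
        attempt = plan(query, query_date, loop_iteration=i,
                       previous_attempts=previous_attempts)
        sessions = execute(attempt["strategy"])
        # outcome is one of "insufficient" | "useful" | "final_answer_ready"
        attempt["outcome"] = judge(query, sessions)
        previous_attempts.append(attempt)
        if attempt["outcome"] == "final_answer_ready":
            break
    return sessions, previous_attempts
```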
197
+
198
+ Use clear reasoning for how each method applies (or not) to the specific query. Be concise and precise.
199
+
200
+ **Do not include strategies in the strategy list unless they are needed for this query. Omit unused methods.**
201
+ **Always use the provided `query_date` to resolve any relative dates in the query.**
202
+ **Strictly follow the given JSON output format. Return only **one** JSON output.**
203
+ **Use time-based filtering ONLY when the query references dates, ranges, or relative temporal expressions (e.g., "last week," "yesterday," "in the last 3 months").**
prompts/keyword_search_prompt.txt ADDED
@@ -0,0 +1,31 @@
1
+ You are a memory retrieval assistant specialized in extracting highly specific search keywords from user queries.
2
+
3
+ Your task: Analyze the user's query and identify the most distinctive, specific keywords that will precisely match relevant memories.
4
+
5
+ Guidelines:
6
+ - Extract ONLY specific, unique terms (proper nouns, specific places, distinct events, particular objects)
7
+ - Prioritize specificity over completeness - fewer precise keywords are better than many generic ones
8
+ - EXCLUDE generic words like: trip, visit, appointment, gift, location, person, time, day
9
+ - EXCLUDE question words: who, what, when, where, why, how
10
+ - EXCLUDE auxiliary verbs: did, was, have, can, should
11
+ - Include 1-3 keywords maximum, ordered by specificity
12
+ - Preserve exact phrasing for proper nouns and named entities
13
+
14
+ Query: "Who did I go to the Grand Canyon with?"
15
+
16
+ Output format (JSON):
17
+ ```json
18
+ {
19
+ "query": "Who did I go to the Grand Canyon with?",
20
+ "keywords": ["Grand Canyon"]
21
+ }
22
+ ```
23
+
24
+ Additional examples:
25
+ - "What did Sarah give me for my birthday?" → ["Sarah", "birthday"]
26
+ - "When did I last visit Dr. Martinez?" → ["Dr. Martinez"]
27
+ - "Where did I put my car keys?" → ["car keys"]
28
+ - "What happened at the team meeting?" → ["team meeting"]
29
+ - "Did I finish the Johnson report?" → ["Johnson report"]
30
+
31
+ Query:
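Once keywords have been extracted with this prompt, applying them to stored sessions can be as simple as the sketch below; the session format mirrors the examples in these prompt files, while the substring-matching rule is an assumption rather than the repo's actual matcher.

```python
# Illustrative only: return indices of sessions that mention any extracted keyword.
def keyword_filter(sessions, keywords):
    hits = []
    for idx, session in enumerate(sessions):
        text = " ".join(turn["content"] for turn in session).lower()
        if any(kw.lower() in text for kw in keywords):
            hits.append(idx)
    return hits

sessions = [
    [{"role": "user", "content": "I went hiking at Mount Rainier on May 10."}],
    [{"role": "user", "content": "The weather was great in Yosemite."}],
]
print(keyword_filter(sessions, ["Mount Rainier"]))  # [0]
```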
prompts/read_and_extract_prompt.txt ADDED
@@ -0,0 +1,176 @@
1
+ You are given a list of chat sessions between a user and an AI assistant.
2
+
3
+ Your task:
4
+ Given a question, identify the sessions that are relevant to answer the question.
5
+
6
+ **Output format:**
7
+
8
+ ```json
9
+ {"index": [<list of 0-based session indices>], "evidence": [<list of sentences that serve as evidence to answer the question>]}
10
+ ```
11
+
12
+ If none are relevant, return:
13
+
14
+ ```json
15
+ {"index": [], "evidence": []}
16
+ ```
17
+
18
+ ---
19
+
20
+ # Examples
21
+
22
+ ---
23
+
24
+ ## Example 1 (multiple relevant sessions)
25
+
26
+ Question: Who did I go hiking with at Mount Rainier?
27
+ Sessions:
28
+
29
+ ### Session Index: 0
30
+
31
+ [{"role": "user", "content": "I went hiking at Mount Rainier on May 10."}, {"role": "assistant", "content": "Nice! Which trail did you take?"}, {"role": "user", "content": "Skyline Trail."}]
32
+
33
+ ### Session Index: 1
34
+
35
+ [{"role": "user", "content": "The weather was great in Yosemite."}, {"role": "assistant", "content": "Did you go there recently?"}]
36
+
37
+ ### Session Index: 2
38
+
39
+ [{"role": "user", "content": "On May 10, I hiked with Sarah and John."}, {"role": "assistant", "content": "Sounds like a fun group!"}]
40
+
41
+ ### Session Index: 5
42
+
43
+ [{"role": "user", "content": "I heard Sarah recently broke up with her boyfriend."}, {"role": "assistant", "content": "Sarah is your close friend from Seattle, right?"}, {"role": "user", "content": "Yes, we’ve been friends since college."}]
44
+
45
+ Explanation:
46
+
47
+ * Session 0 is relevant because it provides the **location and date** (“Mount Rainier on **May 10**”) but no names.
48
+ * Session 1 is irrelevant because it concerns Yosemite, not Mount Rainier.
49
+ * Session 2 is relevant because it provides the **names and date** (“**On May 10**, I hiked with **Sarah and John**”) but no location.
50
+ * Session 5 adds background about Sarah but is not needed to answer the question.
51
+
52
+ The answer requires **combining the shared date (May 10)** across Sessions 0 and 2 to link the names to Mount Rainier.
53
+
54
+ **Final JSON output:**
55
+
56
+ ```json
57
+ {"index": [0, 2], "evidence": ["User went hiking at Mount Rainier on May 10.", "User hiked with Sarah and John on May 10."]}
58
+ ```
59
+
60
+ ---
61
+
62
+ ## Example 2 (no relevant sessions)
63
+
64
+ Question: "When did I buy my new iPhone?"
65
+ Sessions:
66
+
67
+ ### Session Index: 0
68
+
69
+ [{"role": "user", "content": "I love my iPhone 14 camera!"}, {"role": "assistant", "content": "Yes, it takes great photos."}]
70
+
71
+ ### Session Index: 3
72
+
73
+ [{"role": "user", "content": "I’m thinking about buying a new iPhone soon."}, {"role": "assistant", "content": "The new model will be released in the fall."}]
74
+
75
+ Explanation:
76
+
77
+ * None of the sessions give an exact purchase date.
78
+ * Session 0 confirms ownership but not when it was purchased.
79
+ * Session 3 is about a future plan, not an actual purchase.
80
+
81
+ **Final JSON output:**
82
+
83
+ ```json
84
+ {"index": [], "evidence": []}
85
+ ```
86
+
87
+ ---
88
+
89
+ ## Example 3 (single relevant session)
90
+
91
+ Question: "What is the name of my cat?"
92
+ Sessions:
93
+
94
+ ### Session Index: 0
95
+
96
+ [{"role": "user", "content": "I just adopted a cat named Luna."}, {"role": "assistant", "content": "She must be adorable!"}, {"role": "user", "content": "Yes, she’s very playful."}]
97
+
98
+ ### Session Index: 4
99
+
100
+ [{"role": "assistant", "content": "Dogs are usually easier to train than cats."}, {"role": "user", "content": "Yeah, but I love cats."}]
101
+
102
+ Explanation:
103
+
104
+ * Only session 0 contains the name of the cat, “Luna”.
105
+ * Session 4 is general pet advice and irrelevant.
106
+
107
+ **Final JSON output:**
108
+
109
+ ```json
110
+ {"index": [0], "evidence": ["User adopted a cat named Luna."]}
111
+ ```
112
+
113
+ ---
114
+
115
+ ## Example 4 (combining across sessions)
116
+
117
+ Question: "What was the distance of my last two hikes?"
118
+ Sessions:
119
+
120
+ ### Session Index: 0
121
+
122
+ [{"role": "user", "content": "Last weekend, I hiked 5 miles at Storm King Trail."}, {"role": "assistant", "content": "That’s a nice trail."}, {"role": "user", "content": "Yes, the views were amazing."}]
123
+
124
+ ### Session Index: 1
125
+
126
+ [{"role": "user", "content": "Two weeks ago, I did a 7-mile hike at Rattlesnake Ridge."}, {"role": "assistant", "content": "That’s a great workout!"}, {"role": "user", "content": "It was challenging but worth it."}]
127
+
128
+ ### Session Index: 3
129
+
130
+ [{"role": "assistant", "content": "Next time, try Mount Si!"}, {"role": "user", "content": "I’ll add it to my list."}]
131
+
132
+ Explanation:
133
+
134
+ * Session 0 provides the first hike’s distance.
135
+ * Session 1 provides the second hike’s distance.
136
+ * Session 3 only contains the assistant's suggestion for another hike; it has no distance information, and there is no indication that the user actually went on that hike.
137
+
138
+ **Final JSON output:**
139
+
140
+ ```json
141
+ {"index": [0, 1], "evidence": ["User hiked 5 miles at Storm King Trail last weekend.", "User hiked 7 miles at Rattlesnake Ridge two weeks ago."]}
142
+ ```
143
+
144
+ ---
145
+
146
+ ## Example 5 (time reference)
147
+
148
+ Question: "What did I do last Friday?"
149
+ Question date: 2023/05/23 (Tue) 12:26
150
+ Sessions:
151
+
152
+ ### Session Index: 0
153
+ ### Session Date: 2023/05/20 (Sat) 03:14
154
+
155
+ [{"role": "user", "content": "Yesterday, I went to a concert downtown."}, {"role": "assistant", "content": "Who performed?"}, {"role": "user", "content": "It was The Lumineers."}]
156
+
157
+ ### Session Index: 2
158
+ ### Session Date: 2023/05/22 (Mon) 11:59
159
+
160
+ [{"role": "user", "content": "I went hiking last weekend."}, {"role": "assistant", "content": "Where did you go?"}, {"role": "user", "content": "Mount Si."}]
161
+
162
+ Explanation:
163
+
164
+ * Only session 0 is relevant: the session took place on **Saturday, May 20**, and the user says “Yesterday,” which, relative to that session date, refers to **Friday, May 19**. Given the question date (**Tuesday, May 23**), “last Friday” is also **May 19**, so the concert downtown with The Lumineers is what the user did last Friday and answers the question.
165
+ * Session 2 is irrelevant because “last weekend” (relative to Monday, May 22) refers to May 20–21, not **Friday, May 19**, so it does not answer the question.
166
+
167
+ **Final JSON output:**
168
+
169
+ ```json
170
+ {"index": [0], "evidence": ["User went to a concert downtown on last Friday, May 19 and user said it was the Lumineers."]}
171
+ ```
172
+
173
+ ---
174
+
175
+ Use exact indices from the provided list of sessions in your JSON output.
176
+ If the question is related to time, specify the date in the evidence.
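For reference, a minimal sketch of how candidate sessions might be serialized into this prompt and the reply parsed back; the helper names and the reply handling below are illustrative assumptions, not the repo's actual call sites.

```python
# Illustrative only: build the "### Session Index" blocks and parse the JSON reply.
import json

def format_sessions(sessions, dates=None):
    blocks = []
    for idx, session in enumerate(sessions):
        header = f"### Session Index: {idx}"
        if dates is not None:
            header += f"\n### Session Date: {dates[idx]}"
        blocks.append(f"{header}\n\n{json.dumps(session)}")
    return "\n\n".join(blocks)

def parse_reply(raw):
    # Strip an optional markdown code fence around the JSON reply.
    cleaned = raw.strip().strip("`").strip()
    if cleaned.startswith("json"):
        cleaned = cleaned[len("json"):]
    out = json.loads(cleaned.strip())
    return out.get("index", []), out.get("evidence", [])
```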