"""Build v5 retrieval caches in the JSONL format consumed by main.py.

The original README refers to retrieval/run_retrieval.py, but that module is
not present in this checkout. This script reconstructs the two cache formats
used by main.py:

* turn-level episodic cache:
  response_cache/retrieval/flat-gte/<name>_retrievallog_turn_flat-gte
* session-level semantic cache:
  response_cache/retrieval/semantic-gte/<name>_retrievallog_semantic_flat-gte

It embeds the unique corpus once, then ranks each question over its own
haystack sessions.
"""
|
|
from __future__ import annotations
|
|
import argparse
import json
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence, Tuple
|
|
import numpy as np
|
|
|
|
KS = (1, 3, 5, 10, 30, 50, 100)
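# Rank cutoffs evaluated by session_metrics() for recall@k and nDCG@k.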
|
|
|
|
def clean_text(text: Any) -> str:
    """Normalize text so tokenizer UTF-8 encoding cannot fail on bad surrogates."""
    text = "" if text is None else str(text)
    text = text.encode("utf-8", errors="replace").decode("utf-8", errors="replace")
    return text.replace("\ufffd", " ")
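# e.g. clean_text(None) == "", and a lone surrogate such as "\ud800" is
# squashed during the UTF-8 round-trip instead of raising UnicodeEncodeError.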
|
|
|
|
def read_json(path: str | Path) -> Any:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def write_jsonl(path: str | Path, rows: Iterable[Dict[str, Any]]) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    tmp.replace(path)
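# Writing to a .tmp sibling and swapping it in with Path.replace (atomic on
# POSIX filesystems) means readers never observe a half-written cache file.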
|
|
|
|
def ordered_unique_session_ids(entries: Sequence[Dict[str, Any]]) -> List[str]:
    seen = set()
    out: List[str] = []
    for entry in entries:
        for sid in entry["haystack_session_ids"]:
            if sid not in seen:
                seen.add(sid)
                out.append(sid)
    return out
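# e.g. haystacks ["s2", "s1"] and ["s1", "s3"] yield ["s2", "s1", "s3"]:
# first-seen order across entries, duplicates dropped.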
|
|
|
|
def build_turn_corpus(
    entries: Sequence[Dict[str, Any]],
    all_sessions: Dict[str, List[Dict[str, str]]],
) -> Tuple[List[Dict[str, Any]], Dict[str, List[int]]]:
    corpus: List[Dict[str, Any]] = []
    sid_to_indices: Dict[str, List[int]] = defaultdict(list)

    for sid in ordered_unique_session_ids(entries):
        turns = all_sessions.get(sid)
        if not turns:
            continue
        for turn_idx, msg in enumerate(turns, start=1):
            text = clean_text(msg.get("content") or "")
            if not text.strip():
                continue
            item = {
                "corpus_id": f"{sid}_{turn_idx}",
                "sid": sid,
                "text": text,
            }
            sid_to_indices[sid].append(len(corpus))
            corpus.append(item)
    return corpus, sid_to_indices
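# Each item is {"corpus_id": "<sid>_<turn_idx>", "sid": ..., "text": ...} with
# 1-based turn indices; empty turns are skipped, so corpus_id indices can have
# gaps within a session.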
|
|
|
|
def build_semantic_corpus(
    entries: Sequence[Dict[str, Any]],
    summaries: Dict[str, Dict[str, Any]],
    facts: Dict[str, List[Dict[str, Any]]],
) -> Tuple[List[Dict[str, Any]], Dict[str, List[int]]]:
    corpus: List[Dict[str, Any]] = []
    sid_to_indices: Dict[str, List[int]] = defaultdict(list)

    for sid in ordered_unique_session_ids(entries):
        summary_text = clean_text((summaries.get(sid) or {}).get("session_summary") or "")
        if summary_text.strip():
            sid_to_indices[sid].append(len(corpus))
            corpus.append(
                {
                    "corpus_id": sid,
                    "sid": sid,
                    "source": "summary",
                    "text": summary_text,
                }
            )

        fact_texts = []
        for fact in facts.get(sid) or []:
            text = clean_text(fact.get("user-info", "")).strip()
            if text:
                fact_texts.append(text)
        if fact_texts:
            sid_to_indices[sid].append(len(corpus))
            corpus.append(
                {
                    "corpus_id": sid,
                    "sid": sid,
                    "source": "facts",
                    "text": " ".join(fact_texts),
                }
            )

    return corpus, sid_to_indices
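# Each session contributes at most two items, one per source ("summary" and
# "facts"); both reuse the session id as corpus_id, unlike the turn corpus.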
|
|
|
|
def format_query(question: str) -> str:
    question = clean_text(question)
    return (
        "Instruct: Given a question, retrieve relevant passages from a user's "
        "chat history that contain information needed to answer it.\n"
        f"Query: {question}"
    )
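# Produces the instructed-query format used by gte-Qwen2 models, e.g.
#   "Instruct: Given a question, ...\nQuery: Where did I move in March?"
# (the question text here is a made-up illustration).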
|
|
|
|
def shard_bounds(n_items: int, n_shards: int) -> List[Tuple[int, int]]:
    if n_items <= 0:
        # Guard: with zero items, step would be 0 and range() would raise.
        return []
    n_shards = max(1, min(n_shards, n_items))
    step = math.ceil(n_items / n_shards)
    bounds = []
    for start in range(0, n_items, step):
        bounds.append((start, min(start + step, n_items)))
    return bounds
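# e.g. shard_bounds(10, 3) == [(0, 4), (4, 8), (8, 10)]: contiguous,
# near-equal slices covering [0, n_items).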
|
|
|
|
def _last_token_pool(last_hidden_states, attention_mask):
    import torch

    left_padding = bool((attention_mask[:, -1].sum() == attention_mask.shape[0]).item())
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
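# With left padding, the last column holds a real token for every row, so
# column -1 is the last-token embedding; with right padding, each row is
# indexed at its own length - 1. This matches the last-token pooling shown in
# the gte-Qwen2 usage examples.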
|
|
|
|
def _encode_worker(
    rank: int,
    gpu_id: int,
    texts: List[str],
    out_path: str,
    model_name: str,
    batch_size: int,
    max_length: int,
    dtype: str,
) -> None:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    import torch
    import torch.nn.functional as F
    from transformers import AutoModel, AutoTokenizer

    if Path(out_path).exists():
        arr = np.load(out_path, mmap_mode="r")
        if arr.shape[0] == len(texts):
            print(f"[worker {rank}] shard exists, skipping: {out_path}", flush=True)
            return

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }[dtype]

    print(
        f"[worker {rank}] loading {model_name} on {device}; "
        f"{len(texts)} texts",
        flush=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        padding_side="left",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch_dtype if device != "cpu" else torch.float32,
        low_cpu_mem_usage=True,
    )
    model.to(device)
    model.eval()
    if hasattr(model, "config"):
        model.config.use_cache = False

    chunks = []
    with torch.inference_mode():
        for start in range(0, len(texts), batch_size):
            batch = [str(x) for x in texts[start : start + batch_size]]
            encoded = tokenizer(
                batch,
                max_length=max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            encoded = {k: v.to(device) for k, v in encoded.items()}
            outputs = model(**encoded, use_cache=False)
            emb = _last_token_pool(outputs.last_hidden_state, encoded["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)
            chunks.append(emb.detach().cpu().to(torch.float16).numpy())

            if rank == 0 and (start // batch_size) % 100 == 0:
                print(f"[worker {rank}] encoded {start + len(batch)}/{len(texts)}", flush=True)

    arr = np.concatenate(chunks, axis=0) if chunks else np.empty((0, 0), dtype=np.float16)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    np.save(out_path, arr)
    print(f"[worker {rank}] wrote {out_path} {arr.shape}", flush=True)
|
|
|
|
def encode_texts(
    texts: List[str],
    cache_path: Path,
    model_name: str,
    batch_size: int,
    max_length: int,
    dtype: str,
    num_gpus: int,
) -> np.ndarray:
    if cache_path.exists():
        print(f"[cache] loading embeddings: {cache_path}", flush=True)
        return np.load(cache_path, mmap_mode="r")

    import torch.multiprocessing as mp

    tmp_dir = cache_path.parent / (cache_path.name + ".shards")
    tmp_dir.mkdir(parents=True, exist_ok=True)

    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if visible:
        gpu_ids = [int(x) for x in visible.split(",") if x.strip()]
    else:
        try:
            import torch

            gpu_ids = list(range(torch.cuda.device_count()))
        except Exception:
            gpu_ids = []
    if not gpu_ids:
        gpu_ids = [0]
    gpu_ids = gpu_ids[: max(1, num_gpus)]

    bounds = shard_bounds(len(texts), len(gpu_ids))
    processes = []
    mp.set_start_method("spawn", force=True)
    for rank, (start, end) in enumerate(bounds):
        shard_path = tmp_dir / f"shard_{rank:02d}.npy"
        shard_texts = texts[start:end]
        p = mp.Process(
            target=_encode_worker,
            args=(
                rank,
                gpu_ids[rank % len(gpu_ids)],
                shard_texts,
                str(shard_path),
                model_name,
                batch_size,
                max_length,
                dtype,
            ),
        )
        p.start()
        processes.append((p, shard_path))

    for p, shard_path in processes:
        p.join()
        if p.exitcode != 0:
            raise RuntimeError(f"embedding worker failed: {p.pid} exit={p.exitcode} shard={shard_path}")

    shards = [np.load(path, mmap_mode="r") for _, path in processes]
    arr = np.concatenate(shards, axis=0)
    np.save(cache_path, arr.astype(np.float16, copy=False))
    print(f"[cache] wrote embeddings: {cache_path} {arr.shape}", flush=True)
    return np.load(cache_path, mmap_mode="r")
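# Usage sketch (toy inputs; the 3584-dim output assumes the default
# gte-Qwen2-7B-instruct model):
#   emb = encode_texts(["hello", "world"], Path("/tmp/emb.npy"),
#                      "Alibaba-NLP/gte-Qwen2-7B-instruct", 2, 512, "fp32", 1)
#   assert emb.shape == (2, 3584)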
|
|
|
|
def session_metrics(ranked_sids: List[str], answer_sids: Sequence[str], key: str) -> Dict[str, Dict[str, float]]:
    answer = set(answer_sids)
    metrics: Dict[str, float] = {}
    if not answer:
        # Keep the key set consistent with the non-empty case so downstream
        # aggregation never hits missing recall_all/ndcg entries.
        for k in KS:
            metrics[f"recall_any@{k}"] = 0.0
            metrics[f"recall_all@{k}"] = 0.0
            metrics[f"ndcg_any@{k}"] = 0.0
        return {key: metrics}

    for k in KS:
        top = ranked_sids[:k]
        top_set = set(top)
        hit_count = len(top_set & answer)
        metrics[f"recall_any@{k}"] = 1.0 if hit_count > 0 else 0.0
        metrics[f"recall_all@{k}"] = 1.0 if answer.issubset(top_set) else 0.0

        rel = [1.0 if sid in answer else 0.0 for sid in top]
        dcg = sum(r / math.log2(i + 2) for i, r in enumerate(rel))
        ideal_hits = min(len(answer), k)
        idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_hits))
        metrics[f"ndcg_any@{k}"] = dcg / idcg if idcg else 0.0

    return {key: metrics}
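# Worked example: ranked_sids = ["b", "a"], answer_sids = ["a"], k = 3 gives
# recall_any@3 = recall_all@3 = 1.0 and ndcg_any@3 = (1 / log2(3)) / 1.0 ~= 0.63.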
|
|
|
|
def make_ranked_rows(
    entries: Sequence[Dict[str, Any]],
    corpus: Sequence[Dict[str, Any]],
    corpus_emb: np.ndarray,
    query_emb: np.ndarray,
    sid_to_indices: Dict[str, List[int]],
    out_kind: str,
) -> List[Dict[str, Any]]:
    import torch

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Half-precision matmul is not reliably supported on CPU, so fall back to
    # float32 there; on GPU, float16 keeps the score tensors small.
    score_dtype = torch.float16 if device != "cpu" else torch.float32
    corpus_t = torch.as_tensor(np.asarray(corpus_emb), dtype=score_dtype, device=device)
    query_t = torch.as_tensor(np.asarray(query_emb), dtype=score_dtype, device=device)

    rows = []
    for qi, entry in enumerate(entries):
        date_lookup = dict(zip(entry["haystack_session_ids"], entry["haystack_dates"]))
        candidate_indices: List[int] = []
        for sid in entry["haystack_session_ids"]:
            candidate_indices.extend(sid_to_indices.get(sid, []))

        idx = torch.as_tensor(candidate_indices, dtype=torch.long, device=device)
        scores = torch.matmul(corpus_t.index_select(0, idx), query_t[qi])
        order = torch.argsort(scores, descending=True).detach().cpu().numpy().tolist()
        ranked_indices = [candidate_indices[i] for i in order]

        ranked_items = []
        ranked_sids = []
        for ci in ranked_indices:
            item = corpus[ci]
            sid = item["sid"]
            ranked_sids.append(sid)
            out_item = {
                "corpus_id": item["corpus_id"],
                "text": item["text"],
                "timestamp": date_lookup.get(sid, ""),
            }
            if out_kind == "semantic":
                out_item["source"] = item["source"]
            ranked_items.append(out_item)

        metric_key = "turn" if out_kind == "turn" else "session"
        rows.append(
            {
                "question_id": entry["question_id"],
                "question_type": entry["question_type"],
                "question": entry["question"],
                "answer": entry["answer"],
                "question_date": entry["question_date"],
                "answer_session_ids": entry["answer_session_ids"],
                "retrieval_results": {
                    "query": entry["question"],
                    "ranked_items": ranked_items,
                    "metrics": session_metrics(ranked_sids, entry["answer_session_ids"], metric_key),
                },
            }
        )

        if (qi + 1) % 25 == 0:
            print(f"[rank {out_kind}] {qi + 1}/{len(entries)}", flush=True)

    return rows
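# Each output row echoes the question fields from the entry and adds
# retrieval_results: the query, ranked corpus items (with timestamps), and
# per-question session metrics keyed by "turn" or "session".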
|
|
|
|
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_file", default="dataset/evolv_mem_v5.json")
    parser.add_argument("--all_sessions_file", default="dataset/all_sessions.json")
    parser.add_argument("--summary_file", default="dataset/all_session_summary.json")
    parser.add_argument("--facts_file", default="dataset/all_session_user_facts.json")
    parser.add_argument("--out_turn", default="response_cache/retrieval/flat-gte/evolv_mem_v5_retrievallog_turn_flat-gte")
    parser.add_argument("--out_semantic", default="response_cache/retrieval/semantic-gte/evolv_mem_v5_retrievallog_semantic_flat-gte")
    parser.add_argument("--work_dir", default="response_cache/retrieval/.tmp/evolv_mem_v5_flat-gte")
    parser.add_argument("--model_name", default="Alibaba-NLP/gte-Qwen2-7B-instruct")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--query_batch_size", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    parser.add_argument("--num_gpus", type=int, default=8)
    parser.add_argument("--skip_turn", action="store_true")
    parser.add_argument("--skip_semantic", action="store_true")
    args = parser.parse_args()

    work_dir = Path(args.work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    entries = read_json(args.in_file)
    all_sessions = read_json(args.all_sessions_file)
    print(f"[data] entries={len(entries)} all_sessions={len(all_sessions)}", flush=True)

    query_texts = [format_query(entry["question"]) for entry in entries]
    query_emb = encode_texts(
        query_texts,
        work_dir / "query_embeddings.npy",
        args.model_name,
        args.query_batch_size,
        args.max_length,
        args.dtype,
        args.num_gpus,
    )

    if not args.skip_turn:
        turn_corpus, turn_sid_to_indices = build_turn_corpus(entries, all_sessions)
        print(f"[turn] corpus_items={len(turn_corpus)}", flush=True)
        turn_emb = encode_texts(
            [x["text"] for x in turn_corpus],
            work_dir / "turn_embeddings.npy",
            args.model_name,
            args.batch_size,
            args.max_length,
            args.dtype,
            args.num_gpus,
        )
        turn_rows = make_ranked_rows(
            entries,
            turn_corpus,
            turn_emb,
            query_emb,
            turn_sid_to_indices,
            "turn",
        )
        write_jsonl(args.out_turn, turn_rows)
        print(f"[turn] wrote {args.out_turn}", flush=True)

    if not args.skip_semantic:
        summaries = read_json(args.summary_file)
        facts = read_json(args.facts_file)
        semantic_corpus, semantic_sid_to_indices = build_semantic_corpus(entries, summaries, facts)
        print(f"[semantic] corpus_items={len(semantic_corpus)}", flush=True)
        semantic_emb = encode_texts(
            [x["text"] for x in semantic_corpus],
            work_dir / "semantic_embeddings.npy",
            args.model_name,
            args.batch_size,
            args.max_length,
            args.dtype,
            args.num_gpus,
        )
        semantic_rows = make_ranked_rows(
            entries,
            semantic_corpus,
            semantic_emb,
            query_emb,
            semantic_sid_to_indices,
            "semantic",
        )
        write_jsonl(args.out_semantic, semantic_rows)
        print(f"[semantic] wrote {args.out_semantic}", flush=True)
|
|
|
|
if __name__ == "__main__":
    main()
|
|