| |
|
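"""
generate_prompts_v8_batch_fixed.py

- Retrieves contexts with one batch call per chunk, and similar QA pairs and
  relationships with threaded per-question calls
- Saves prompts in batch files, with checkpointing so a run can resume
- Pads retrieved contexts to top_k and always emits exactly five similar-QA slots
- Appends context 1 and its metadata to the end of the gold answer
"""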
|
| |
|
| | import os, json, torch, numpy as np
|
| | from pathlib import Path
|
| | from tqdm import tqdm
|
| | from sentence_transformers import SentenceTransformer
|
| | from concurrent.futures import ThreadPoolExecutor
|
| |
|
| | from context_retreiver import retriever as context_retriever
|
| | from qa_retreiver import search_topk as qa_retreiver
|
| | from relationships_retreiver import batch_relationships
|
| |
|
# --- Input / output locations ----------------------------------------------
QA_FILE = Path("got_all_qa_final.json")  # QA dataset read by main()
OUT_DIR = Path("prompts_out")  # batch files, checkpoint and token list go here
CHECKPOINT_FILE = OUT_DIR / "checkpoint.json"

# --- Batch sizes -------------------------------------------------------------
SAVE_BATCH_SIZE = 512  # prompt records per prompts_batch_NNN.json file
EMBED_BATCH_SIZE = 32  # QA items handed to the retrievers per loop step

# --- Device / model ----------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
print(f"[INFO] Using device: {DEVICE}")

# NOTE(review): EMBED_MODEL is not referenced by any code visible in this file,
# yet it is loaded at import time. Confirm it is still needed.
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)
|
| |
|
# Opening/closing tag pairs used by build_prompt(), in declaration order:
# five section tags followed by the five similar-QA slot tags.
STRUCTURAL_TOKENS = [
    token
    for section in (
        "CTX_QA",
        "CTX_REL",
        "INSTR",
        "QUESTION",
        "ANSWER",
        *(f"QA_SIM_{slot}" for slot in range(1, 6)),
    )
    for token in (f"<|{section}|>", f"<|/{section}|>")
]
|
| |
|
def read_checkpoint(path=None):
    """Return the index of the next QA item to process, from the checkpoint file.

    Args:
        path: Checkpoint file to read. Defaults to ``CHECKPOINT_FILE``.

    Returns:
        The stored ``next_index`` as a non-negative int. Returns 0 (restart
        from the beginning) when the file is missing, cannot be read or
        parsed, or lacks an integer-convertible ``next_index``.
    """
    path = CHECKPOINT_FILE if path is None else Path(path)
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
        next_index = int(data["next_index"])
    except FileNotFoundError:
        return 0
    except (OSError, ValueError, KeyError, TypeError) as err:
        # ValueError covers malformed JSON and non-numeric values; TypeError
        # covers a non-dict payload or a null index. Warn rather than silently
        # restarting, but do not abort the run.
        print(f"[WARN] Ignoring unreadable checkpoint {path}: {err}")
        return 0
    # A negative index would make later slicing start from the end of the data.
    return max(next_index, 0)
|
| |
|
def write_checkpoint(idx, path=None):
    """Persist ``idx`` as the next QA index to process.

    The JSON ``{"next_index": idx}`` is written to a temporary sibling file,
    which is then moved over the checkpoint with ``os.replace``. An
    interruption mid-write therefore cannot leave a truncated checkpoint.
    A truncated file would be unparseable, and read_checkpoint() falls back
    to 0 for that, which would restart the whole run.

    Args:
        idx: Index of the first item that still needs processing.
        path: Checkpoint file to write. Defaults to ``CHECKPOINT_FILE``.
    """
    path = CHECKPOINT_FILE if path is None else Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps({"next_index": idx}), encoding="utf-8")
    os.replace(tmp_path, path)
|
| |
|
def metadata_to_str(meta):
    """Render the scalar entries of ``meta`` as ``key=value`` pairs joined by ``"; "``.

    Only str/int/float/bool values are included; anything else (lists, dicts,
    None, ...) is skipped. A falsy ``meta`` (None, empty dict) gives ``""``.
    """
    if not meta:
        return ""
    scalar_types = (str, int, float, bool)
    pairs = []
    for key, value in meta.items():
        if isinstance(value, scalar_types):
            pairs.append(f"{key}={value}")
    return "; ".join(pairs)
|
| |
|
def append_metadata_at_end(answer, context1_text, context1_meta):
    """Build the gold answer string: answer, then context 1, then its metadata.

    Pieces are joined by single spaces, and a piece is omitted when its input
    is falsy:
      * ``answer`` stripped of surrounding whitespace;
      * ``[Context1: <stripped context1_text>]``;
      * ``(meta: <metadata_to_str(context1_meta)>)`` when that string is non-empty.
    """
    answer_part = answer.strip() if answer else None
    context_part = f"[Context1: {context1_text.strip()}]" if context1_text else None
    meta_str = metadata_to_str(context1_meta)
    meta_part = f"(meta: {meta_str})" if meta_str else None
    return " ".join(part for part in (answer_part, context_part, meta_part) if part is not None)
|
| |
|
def build_prompt(ctx_texts, rel_text, sim_qas, question):
    """Assemble the tagged prompt for one question.

    Sections are joined by blank lines, in this order:
      * one ``<|CTX_QA|>`` section per truthy entry of ``ctx_texts``;
      * a ``<|CTX_REL|>`` section when ``rel_text`` is truthy;
      * exactly five ``<|QA_SIM_n|>`` slots, filled from the first five
        ``sim_qas`` entries (read via ``['question']`` / ``['answer']``) and
        left empty when fewer entries exist;
      * the fixed instruction, the question, and an open ``<|ANSWER|>`` tag.
    """
    sections = [f"<|CTX_QA|> {ctx} <|/CTX_QA|>" for ctx in ctx_texts if ctx]
    if rel_text:
        sections.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")

    for slot in range(1, 6):
        if slot <= len(sim_qas):
            qa = sim_qas[slot - 1]
            sections.append(f"<|QA_SIM_{slot}|> Q: {qa['question']} A: {qa['answer']} <|/QA_SIM_{slot}|>")
        else:
            sections.append(f"<|QA_SIM_{slot}|> <|/QA_SIM_{slot}|>")

    sections.extend([
        "<|INSTR|> Use above contexts to answer concisely. <|/INSTR|>",
        f"<|QUESTION|> {question} <|/QUESTION|>",
        "<|ANSWER|>",
    ])
    return "\n\n".join(sections)
|
| |
|
def retrieve_contexts(questions, top_k=3):
    """Fetch the top ``top_k`` contexts for all ``questions`` with one batch call.

    Returns a list with one ``(texts, metadatas)`` tuple per result list from
    ``context_retriever.batch_retrieve``. Each list has exactly ``top_k``
    entries: extra hits are dropped, and missing ones are padded with ``""``
    (text) and a fresh ``{}`` (metadata).
    """
    batch_res = context_retriever.batch_retrieve(questions, top_k=top_k)

    contexts = []
    for results in batch_res:
        top_hits = results[:top_k]
        texts = [hit["text"] for hit in top_hits]
        metas = [hit["metadata"] for hit in top_hits]
        shortfall = top_k - len(texts)
        texts.extend([""] * shortfall)
        metas.extend({} for _ in range(shortfall))
        contexts.append((texts, metas))
    return contexts
|
| |
|
def retrieve_qas_and_rels(questions, max_workers=20):
    """Retrieve similar QA pairs and relationship text for each question, using threads.

    Each question is sent individually to ``qa_retreiver`` (``k=5``) and to
    ``batch_relationships`` (``top_k=1``, first result kept). Both groups of
    calls are submitted before either is collected, so QA and relationship
    lookups overlap instead of running one after the other.

    Returns:
        ``(sim_qas_list, rel_list)``, each aligned index-for-index with
        ``questions``.
    """
    # NOTE(review): the relationship result is unwrapped with [0] but the QA
    # result is not. If qa_retreiver returns one result list per query, each
    # sim_qas entry would be a list of lists. Confirm against qa_retreiver.
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        # Executor.map submits every call immediately; only the list() calls
        # below wait for results, so both fan-outs run concurrently.
        qa_results = ex.map(lambda q: qa_retreiver([q], k=5), questions)
        rel_results = ex.map(lambda q: batch_relationships([q], top_k=1)[0], questions)
        return list(qa_results), list(rel_results)
|
| |
|
def _save_batch(prompts, batch_idx):
    """Write ``prompts`` as UTF-8 JSON to ``OUT_DIR/prompts_batch_<batch_idx:03d>.json``."""
    out_path = OUT_DIR / f"prompts_batch_{batch_idx:03d}.json"
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be explicit rather than the platform default.
    out_path.write_text(json.dumps(prompts, ensure_ascii=False, indent=2), encoding="utf-8")


def main():
    """Generate retrieval-augmented prompts for every QA item in ``QA_FILE``.

    Items are processed in chunks of ``EMBED_BATCH_SIZE``, starting at the
    checkpointed index. Prompt records are accumulated and written to
    ``OUT_DIR/prompts_batch_NNN.json`` every ``SAVE_BATCH_SIZE`` records. Any
    remainder is written at the end, together with ``special_tokens_used.txt``.

    The checkpoint only advances after a batch file has been written. A crash
    therefore never skips records that were still held in memory; at most the
    unsaved tail is recomputed on the next run.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(QA_FILE, "r", encoding="utf-8") as f:
        qas = json.load(f)

    total = len(qas)
    start_idx = read_checkpoint()
    if start_idx >= total:
        print("[INFO] Checkpoint beyond dataset length.")
        return

    prompts_accum = []
    # Each saved file holds SAVE_BATCH_SIZE records drawn from at least that
    # many input indices. So start_idx // SAVE_BATCH_SIZE is never smaller
    # than the number of files already written, and none get overwritten.
    batch_count = start_idx // SAVE_BATCH_SIZE

    for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
        batch_items = qas[batch_start:batch_start + EMBED_BATCH_SIZE]
        questions = [it.get("question") or it.get("q") or it.get("Question") for it in batch_items]
        orig_answers = [it.get("answer") or it.get("a") or it.get("Answer", "") for it in batch_items]

        # Items without a question are skipped. Keep them out of the retrieval
        # calls too, so a None/empty entry never reaches the retrievers.
        valid_pos = [pos for pos, q in enumerate(questions) if q]
        if valid_pos:
            valid_questions = [questions[pos] for pos in valid_pos]
            contexts = retrieve_contexts(valid_questions, top_k=3)
            sim_qas_list, rel_list = retrieve_qas_and_rels(valid_questions)
        else:
            contexts, sim_qas_list, rel_list = [], [], []

        for j, pos in enumerate(valid_pos):
            q = questions[pos]
            ctx_texts, ctx_metas = contexts[j]
            # retrieve_contexts pads/truncates to top_k=3, so this unpacks safely.
            # Context 1 goes into the gold answer; contexts 2 and 3 go into the prompt.
            context1, context2, context3 = ctx_texts
            prompt_text = build_prompt([context2, context3], rel_list[j], sim_qas_list[j], q)
            gold = append_metadata_at_end(orig_answers[pos], context1, ctx_metas[0])

            prompts_accum.append({
                "id": batch_start + pos,
                "question": q,
                "prompt": prompt_text,
                "gold_answer": gold,
                "context1": context1,
                "retrieved_qas": sim_qas_list[j],
                "relation_text": rel_list[j],
            })

            if len(prompts_accum) >= SAVE_BATCH_SIZE:
                _save_batch(prompts_accum, batch_count)
                batch_count += 1
                prompts_accum = []
                # Every record up to and including this item is now on disk.
                write_checkpoint(batch_start + pos + 1)

    if prompts_accum:
        _save_batch(prompts_accum, batch_count)
    write_checkpoint(total)

    OUT_DIR.joinpath("special_tokens_used.txt").write_text("\n".join(STRUCTURAL_TOKENS), encoding="utf-8")
    print("[DONE] All prompts processed.")


if __name__ == "__main__":
    main()
|
| |
|