Caden Shokat
committed on
Commit
·
14c03a7
1
Parent(s):
49a5af2
processing/chunking updates
Browse files- src/processing/generate_qas.py +72 -39
- src/processing/load_chunks.py +8 -4
- src/processing/output.jsonl +0 -0
src/processing/generate_qas.py
CHANGED
|
@@ -1,70 +1,103 @@
|
|
| 1 |
-
import os, json, glob, time
|
| 2 |
from typing import List, Dict
|
| 3 |
from dotenv import load_dotenv
|
|
|
|
| 4 |
from load_chunks import load_all_chunks
|
| 5 |
from openai import OpenAI
|
| 6 |
|
| 7 |
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 8 |
NUM_QUESTIONS = 4
|
| 9 |
SLEEP = 5
|
| 10 |
-
|
| 11 |
|
| 12 |
def make_prompt(chunk_text: str) -> str:
|
| 13 |
return f"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
- id: integer question number (1..{NUM_QUESTIONS})
|
| 19 |
-
- question: the question text
|
| 20 |
|
| 21 |
-
|
| 22 |
-
\"\"\"
|
| 23 |
-
{chunk_text}
|
| 24 |
-
\"\"\"
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
def generate(prompt: str) -> List[Dict]:
|
| 28 |
while True:
|
| 29 |
try:
|
| 30 |
resp = openai.chat.completions.create(
|
| 31 |
-
model=
|
| 32 |
messages=[
|
| 33 |
-
{ "role": "system", "content":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
{ "role": "user", "content": prompt }
|
| 35 |
],
|
| 36 |
-
temperature=0.
|
| 37 |
-
max_tokens=NUM_QUESTIONS *
|
| 38 |
)
|
| 39 |
text = resp.choices[0].message.content.strip()
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
except json.JSONDecodeError as e:
|
| 42 |
-
print("
|
| 43 |
time.sleep(1)
|
| 44 |
|
| 45 |
def main():
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
"section": chunk.get("section"),
|
| 56 |
-
"chunk_id": chunk["chunk_id"],
|
| 57 |
-
"question_id": qa["id"],
|
| 58 |
-
"question": qa["question"]
|
| 59 |
-
})
|
| 60 |
|
| 61 |
-
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
|
| 69 |
if __name__ == "__main__":
|
| 70 |
main()
|
|
|
|
| 1 |
+
import os, json, glob, time, re
from typing import List, Dict
from dotenv import load_dotenv
import argparse
from load_chunks import load_all_chunks
from openai import OpenAI

# Load .env before reading OPENAI_API_KEY: previously load_dotenv was
# imported but never called, so keys kept in a .env file were ignored
# and the client below was built with api_key=None.
load_dotenv()

openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
NUM_QUESTIONS = 4       # questions requested per chunk
SLEEP = 5               # legacy default pause between requests; the CLI --sleep flag overrides it
model = "gpt-4o-mini"   # default chat model for --model
|
| 12 |
|
| 13 |
def make_prompt(chunk_text: str) -> str:
    """Build the user prompt for question generation.

    Args:
        chunk_text: Raw text of one document chunk.

    Returns:
        Prompt string embedding the chunk between triple-quote delimiters
        and instructing the model to reply with JSON only.
    """
    # Built from plain concatenation instead of an indented triple-quoted
    # f-string: the old version sent its own leading blank lines and source
    # indentation to the model, and referred to "the above rules" although
    # the rules actually live in the system message.
    return (
        "Generate questions according to the system instructions. "
        "Return **only** JSON. **All** string fields must be valid JSON "
        "strings wrapped in double quotes.\n\n"
        f'Here is the text chunk:\n\n"""\n{chunk_text}\n"""\n'
    )
|
|
|
|
|
|
| 19 |
|
| 20 |
+
def generate(model: str, prompt: str, max_retries: int = 5) -> List[Dict]:
    """Request QA pairs from the chat model and parse its JSON reply.

    Args:
        model: OpenAI chat model name (e.g. "gpt-4o-mini").
        prompt: User prompt produced by make_prompt().
        max_retries: Malformed replies tolerated before giving up
            (the old version retried forever).

    Returns:
        List of dicts, each expected to carry "question" and "answer_span".

    Raises:
        ValueError: If no parseable JSON arrives within max_retries attempts.
    """
    for attempt in range(max_retries):
        resp = openai.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": (
                    "You are an expert at generating reading-comprehension "
                    "questions in **strict JSON** form.\n"
                    "Given the user's chunk, you will output **only** a JSON "
                    "array of objects—no commentary, no extra text.\n"
                    "Each object must have:\n"
                    "- question : the question text\n"
                    "- answer_span : the exact sentence from the chunk that "
                    "answers this question\n"
                    # Was hard-coded "exactly 4"; keep it in sync with the constant.
                    f"Output exactly {NUM_QUESTIONS} questions."
                )},
                {"role": "user", "content": prompt},
            ],
            temperature=0.2,
            max_tokens=NUM_QUESTIONS * 100,
        )
        text = resp.choices[0].message.content.strip()

        # Strip a Markdown code fence the model may wrap around the JSON.
        raw = re.sub(r"^```(?:json)?\s*", "", text)
        raw = re.sub(r"```$", "", raw).strip()

        # Fall back to the outermost [...] span if extra prose leaked in.
        m = re.search(r"\[.*\]", raw, flags=re.S)
        if m:
            raw = m.group(0)

        try:
            return json.loads(raw)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON (attempt {attempt + 1}/{max_retries}), retrying...", e)
            time.sleep(1)

    raise ValueError(f"Model returned unparseable JSON after {max_retries} attempts")
|
| 58 |
|
| 59 |
def main():
    """CLI entry point: generate QA pairs for every chunk and write JSONL.

    Positional args: a glob for chunk JSON files and an output JSONL path.
    Each output line is one question record with provenance fields.
    """
    parser = argparse.ArgumentParser(
        description="Generate QA pairs from chunk JSON files via an OpenAI chat model"
    )
    parser.add_argument("chunks_glob",
                        help="Glob pattern for chunk JSON files (e.g. 'chunks/**/*.json')")
    parser.add_argument("output",
                        help="Output JSONL file for QA pairs")
    # Help text previously claimed "default: gpt-4" while the real default
    # was gpt-4o-mini; derive it from the constant so they cannot drift.
    parser.add_argument("--model", default=model,
                        help=f"OpenAI model to use (default: {model})")
    parser.add_argument("--sleep", type=float, default=0.5,
                        help="Seconds to sleep between requests (default: 0.5)")
    args = parser.parse_args()

    # Fail fast with a usage error instead of an opaque API auth failure
    # later. (Reassigning openai.api_key here, as the old code did, had no
    # effect: the client was already constructed at import time.)
    if not os.getenv("OPENAI_API_KEY"):
        parser.error("Please set OPENAI_API_KEY environment variable")

    chunks = load_all_chunks(args.chunks_glob)
    print(f"Loaded {len(chunks)} chunks.")

    with open(args.output, "w", encoding="utf-8") as out_f:
        total = 0
        for rec in chunks:
            qas = generate(args.model, make_prompt(rec["text"]))
            # question_id is 1-based within a chunk; global_id is 0-based
            # across the whole run (matches the original counters).
            for question_id, qa in enumerate(qas, start=1):
                out = {
                    "global_id": total,
                    "doc_id": rec["doc_id"],
                    "chunk_id": rec["chunk_id"],
                    "question_id": question_id,
                    "question": qa["question"],
                    "answer_span": qa["answer_span"],
                    "chunk": rec.get("text"),
                }
                out_f.write(json.dumps(out, ensure_ascii=False) + "\n")
                total += 1
            time.sleep(args.sleep)

    print(f"Done — generated {total} questions across {len(chunks)} chunks into '{args.output}'.")
|
| 101 |
|
| 102 |
# Script entry point: run the QA-generation CLI when executed directly.
if __name__ == "__main__":
    main()
|
src/processing/load_chunks.py
CHANGED
|
@@ -2,8 +2,12 @@ from typing import List, Dict
|
|
| 2 |
import json, glob
|
| 3 |
|
| 4 |
def load_all_chunks(glob_pattern: str) -> List[Dict]:
|
| 5 |
-
|
| 6 |
for path in glob.glob(glob_pattern, recursive=True):
|
| 7 |
-
data = json.load(open(path, encoding=
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import json, glob
|
| 3 |
|
| 4 |
def load_all_chunks(glob_pattern: str) -> List[Dict]:
    """Load and validate chunk records from every JSON file matching a glob.

    Args:
        glob_pattern: Glob pattern (may use '**') selecting JSON files,
            each holding a list of chunk records.

    Returns:
        All records concatenated in glob order.

    Raises:
        ValueError: If any record lacks "doc_id", "chunk_id", or "text".
    """
    chunks = []
    for path in glob.glob(glob_pattern, recursive=True):
        # Context manager closes the handle promptly; the original
        # json.load(open(...)) leaked it until garbage collection.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for rec in data:
            if "doc_id" not in rec or "chunk_id" not in rec or "text" not in rec:
                raise ValueError(f"Missing required keys in chunk record: {rec}")
            chunks.append(rec)
    return chunks
|
| 13 |
+
|
src/processing/output.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|