import json
import os
import random

from tqdm import tqdm
from transformers import AutoTokenizer

# CONFIG
tokenizer = AutoTokenizer.from_pretrained("gpt2")
MAX_TOKENS = 27      # per-sample budget for the combined "Q: ... A: ..." text
NUM_SAMPLES = 50000  # target dataset size
SAVE_PATH = "./customgens/mini_qna_dataset.jsonl"
# Extended Templates with Paraphrasing
TEMPLATES = [
    # WHY
    ("Why do {subject} {action}?", "Because {reason}."),
    ("What makes {subject} {action}?", "It's because {reason}."),
    ("Explain why {subject} {action}.", "{reason} is the reason."),
    # WHAT IS
    ("What is {thing}?", "{thing} is {definition}."),
    ("Define {thing}.", "{thing} refers to {definition}."),
    ("Can you tell me what {thing} means?", "Sure! It's {definition}."),
    # HOW
    ("How does {thing} work?", "It works by {mechanism}."),
    ("What's the mechanism behind {thing}?", "It involves {mechanism}."),
    ("Explain how {thing} functions.", "{mechanism} is how it works."),
    # WHEN / CONDITION
    ("What happens when {condition}?", "{result}."),
    ("Describe what occurs if {condition}.", "Usually, {result}."),
    ("When {condition}, what takes place?", "The result is {result}."),
    # IMPORTANCE
    ("Why is {thing} important?", "Because {importance}."),
    ("What makes {thing} important?", "{importance} is why."),
    ("Is {thing} important? Why?", "Yes, because {importance}."),
]
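
# Illustrative fill of the first WHY template with "animals" data:
#   Q: "Why do cats sleep a lot?"   A: "Because they conserve energy."
# Slot values are sampled independently of each other, so semantically odd
# pairings (e.g. "Why do fish bark?") can and do occur.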
# Knowledge Bank
DATA = {
    "animals": {
        "subjects": ["cats", "dogs", "birds", "fish"],
        "actions": ["sleep a lot", "bark", "fly", "swim"],
        "reasons": [
            "they conserve energy",
            "they are nocturnal",
            "it's in their nature",
            "they communicate that way",
        ],
    },
    "science": {
        "things": ["gravity", "photosynthesis", "a star", "an atom"],
        "definitions": [
            "a force that pulls objects together",
            "the process plants use to make food",
            "a burning ball of gas",
            "the smallest unit of matter",
        ],
        "mechanisms": [
            "converting sunlight into energy",
            "attracting objects with mass",
            "splitting light into colors",
            "colliding particles",
        ],
        "conditions": ["you heat ice", "a star dies"],
        "results": ["it melts", "it becomes a black hole"],
        "importance": [
            "it keeps us on Earth",
            "it enables life on Earth",
        ],
    },
    "food": {
        "things": ["a waffle", "chocolate", "rice", "milk"],
        "definitions": [
            "a sweet, crispy batter cake",
            "a sweet made from cocoa",
            "a grain eaten daily in Asia",
            "a white liquid from cows",
        ],
        "importance": [
            "it provides energy",
            "it’s part of daily nutrition",
        ],
    },
}
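
# Quota bookkeeping keeps the topics balanced. With NUM_SAMPLES = 50000 and
# 3 topics, MAX_PER_TOPIC = 50000 // 3 = 16666, so the run caps out at
# 3 * 16666 = 49998 samples rather than the full 50000.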
TOPIC_COUNT = {k: 0 for k in DATA}
MAX_PER_TOPIC = NUM_SAMPLES // len(DATA)
def sample_topic():
    """Pick a topic whose per-topic quota is not yet exhausted, or None."""
    options = [t for t in DATA if TOPIC_COUNT[t] < MAX_PER_TOPIC]
    return random.choice(options) if options else None
def fill_template(template_pair, topic_data):
    """Fill a (question, answer) template pair with random values from the
    topic, falling back to generic phrases for slots the topic lacks."""
    q_temp, a_temp = template_pair
    replacements = {
        "{subject}": random.choice(topic_data.get("subjects", topic_data.get("things", ["something"]))),
        "{action}": random.choice(topic_data.get("actions", ["do things"])),
        "{reason}": random.choice(topic_data.get("reasons", ["that’s how they survive"])),
        "{thing}": random.choice(topic_data.get("things", ["a thing"])),
        "{definition}": random.choice(topic_data.get("definitions", ["an object used every day"])),
        "{mechanism}": random.choice(topic_data.get("mechanisms", ["processing energy"])),
        "{condition}": random.choice(topic_data.get("conditions", ["a change occurs"])),
        "{result}": random.choice(topic_data.get("results", ["it transforms"])),
        "{importance}": random.choice(topic_data.get("importance", ["it is vital to survival"])),
    }
    q, a = q_temp, a_temp
    for key, val in replacements.items():
        q = q.replace(key, val)
        a = a.replace(key, val)
    return q.strip(), a.strip()
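
# Hypothetical call (actual output depends on the RNG state):
#   fill_template(TEMPLATES[3], DATA["science"])
#   -> ("What is gravity?", "gravity is a force that pulls objects together.")
# Each slot is sampled once, so "{thing}" is identical in question and answer,
# but "{definition}" is drawn independently of it and may not match.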
def maybe_add_noise(q, a):
    """Inject light noise: ~5% unsure answers, another ~5% get filler text."""
    rand = random.random()
    if rand < 0.05:
        a = "I'm not sure."
    elif rand < 0.10:
        q += " Just wondering."
        a = "Well, " + a
    return q, a
def token_count(text):
    return len(tokenizer.encode(text))
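
# Rough sizing note: GPT-2's byte-level BPE averages roughly 0.75 words per
# token on plain English, so MAX_TOKENS = 27 keeps each "Q: ... A: ..." pair
# to about 20 words.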
def main():
    # Make sure the output directory exists before writing.
    os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
    with open(SAVE_PATH, "w", encoding="utf-8") as f:
        total = 0
        pbar = tqdm(total=NUM_SAMPLES)
        while total < NUM_SAMPLES:
            topic = sample_topic()
            if not topic:
                # Every topic has hit its quota; stop early.
                break
            template = random.choice(TEMPLATES)
            topic_data = DATA[topic]
            question, answer = fill_template(template, topic_data)
            question, answer = maybe_add_noise(question, answer)
            combined = f"Q: {question} A: {answer}"
            # Keep only samples that fit the token budget.
            if token_count(combined) <= MAX_TOKENS:
                record = {
                    "question": question,
                    "answer": answer,
                    "text": combined,
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                total += 1
                TOPIC_COUNT[topic] += 1
                pbar.update(1)
                # Print a periodic sample as a quick sanity check.
                if total % 5000 == 0:
                    print(f"\n[Sample {total}]")
                    print("Q:", question)
                    print("A:", answer)
                    print("Tokens:", token_count(combined))
        pbar.close()
    print(f"\n✅ Saved {total} samples to {SAVE_PATH}")


if __name__ == "__main__":
    main()
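
# Minimal sketch of reading the output back (assumes the script has already
# been run, so SAVE_PATH exists):
#
#   with open(SAVE_PATH, encoding="utf-8") as f:
#       rows = [json.loads(line) for line in f]
#   print(rows[0]["text"])  # e.g. 'Q: What is rice? A: rice refers to ...'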