"""Generate a small synthetic question-answer dataset and save it as JSONL.

Samples are built from (question, answer) templates filled with topic-specific
values, lightly noised, and kept only if they fit a small GPT-2 token budget.
"""

import json
import os
import random

from tqdm import tqdm
from transformers import AutoTokenizer

# GPT-2 tokenizer, used only to measure sample length in tokens.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

MAX_TOKENS = 27
NUM_SAMPLES = 50000
SAVE_PATH = "./customgens/mini_qna_dataset.jsonl"
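
# (question, answer) template pairs, grouped by question type. Placeholders in
# braces are filled from the topic pools in DATA below.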
TEMPLATES = [
    ("Why do {subject} {action}?", "Because {reason}."),
    ("What makes {subject} {action}?", "It's because {reason}."),
    ("Explain why {subject} {action}.", "{reason} is the reason."),

    ("What is {thing}?", "{thing} is {definition}."),
    ("Define {thing}.", "{thing} refers to {definition}."),
    ("Can you tell me what {thing} means?", "Sure! It's {definition}."),

    ("How does {thing} work?", "It works by {mechanism}."),
    ("What's the mechanism behind {thing}?", "It involves {mechanism}."),
    ("Explain how {thing} functions.", "{mechanism} is how it works."),

    ("What happens when {condition}?", "{result}."),
    ("Describe what occurs if {condition}.", "Usually, {result}."),
    ("When {condition}, what takes place?", "The result is {result}."),

    ("Why is {thing} important?", "Because {importance}."),
    ("What makes {thing} important?", "{importance} is why."),
    ("Is {thing} important? Why?", "Yes, because {importance}."),
]
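
# Topic-specific value pools. Not every topic defines every pool;
# fill_template() falls back to generic defaults for any missing key.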
DATA = {
    "animals": {
        "subjects": ["cats", "dogs", "birds", "fish"],
        "actions": ["sleep a lot", "bark", "fly", "swim"],
        "reasons": [
            "they conserve energy",
            "they are nocturnal",
            "it's in their nature",
            "they communicate that way",
        ],
    },
    "science": {
        "things": ["gravity", "photosynthesis", "a star", "an atom"],
        "definitions": [
            "a force that pulls objects together",
            "the process plants use to make food",
            "a burning ball of gas",
            "the smallest unit of matter",
        ],
        "mechanisms": [
            "converting sunlight into energy",
            "attracting objects with mass",
            "splitting light into colors",
            "colliding particles",
        ],
        "conditions": ["you heat ice", "a star dies"],
        "results": ["it melts", "it becomes a black hole"],
        "importance": [
            "it keeps us on Earth",
            "it enables life on Earth",
        ],
    },
    "food": {
        "things": ["a waffle", "chocolate", "rice", "milk"],
        "definitions": [
            "a sweet, crispy batter cake",
            "a sweet made from cocoa",
            "a grain eaten daily in Asia",
            "a white liquid from cows",
        ],
        "importance": [
            "it provides energy",
            "it's part of daily nutrition",
        ],
    },
}
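
# Per-topic quota: each topic contributes at most an equal share of NUM_SAMPLES.
# Note that 50000 // 3 == 16666, so the three quotas sum to 49998 and the
# script can stop just short of NUM_SAMPLES.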
TOPIC_COUNT = {k: 0 for k in DATA}
MAX_PER_TOPIC = NUM_SAMPLES // len(DATA)


def sample_topic():
    """Pick a random topic that has not yet reached its quota, else None."""
    options = [t for t in DATA if TOPIC_COUNT[t] < MAX_PER_TOPIC]
    return random.choice(options) if options else None


def fill_template(template_pair, topic_data):
    """Fill a (question, answer) template pair with random values for the topic.

    Each placeholder is sampled independently, so paired slots such as
    {condition}/{result} or {thing}/{definition} are not guaranteed to match;
    pools missing from the topic fall back to generic defaults.
    """
    q_temp, a_temp = template_pair
    replacements = {
        "{subject}": random.choice(topic_data.get("subjects", topic_data.get("things", ["something"]))),
        "{action}": random.choice(topic_data.get("actions", ["do things"])),
        "{reason}": random.choice(topic_data.get("reasons", ["that's how they survive"])),
        "{thing}": random.choice(topic_data.get("things", ["a thing"])),
        "{definition}": random.choice(topic_data.get("definitions", ["an object used every day"])),
        "{mechanism}": random.choice(topic_data.get("mechanisms", ["processing energy"])),
        "{condition}": random.choice(topic_data.get("conditions", ["a change occurs"])),
        "{result}": random.choice(topic_data.get("results", ["it transforms"])),
        "{importance}": random.choice(topic_data.get("importance", ["it is vital to survival"])),
    }

    q, a = q_temp, a_temp
    for key, val in replacements.items():
        q = q.replace(key, val)
        a = a.replace(key, val)
    return q.strip(), a.strip()
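
# For example, with topic "science", the pair ("What is {thing}?", "{thing} is
# {definition}.") might come out as ("What is gravity?", "gravity is the
# smallest unit of matter."), since the {thing} and {definition} draws are
# independent.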


def maybe_add_noise(q, a):
    """Occasionally perturb a sample: ~5% unsure answers, ~5% filler phrasing."""
    rand = random.random()
    if rand < 0.05:
        a = "I'm not sure."
    elif rand < 0.10:
        q += " Just wondering."
        a = "Well, " + a
    return q, a
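
# Note: noise interacts with the token filter in main(). "Just wondering." /
# "Well," lengthen a sample and so get rejected a bit more often, while the
# short "I'm not sure." answers pass more easily.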


def token_count(text):
    """Length of text in GPT-2 tokens (plain encode; GPT-2 adds no special tokens)."""
    return len(tokenizer.encode(text))


def main():
    # Make sure the output directory exists before opening the file for writing.
    os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)

    with open(SAVE_PATH, "w", encoding="utf-8") as f:
        total = 0
        pbar = tqdm(total=NUM_SAMPLES)

        while total < NUM_SAMPLES:
            topic = sample_topic()
            if not topic:
                # All topics hit their quota (the quotas sum to 49998, not 50000).
                break
            template = random.choice(TEMPLATES)
            topic_data = DATA[topic]

            question, answer = fill_template(template, topic_data)
            question, answer = maybe_add_noise(question, answer)

            combined = f"Q: {question} A: {answer}"
            # Keep the sample only if it fits the token budget; otherwise resample.
            if token_count(combined) <= MAX_TOKENS:
                record = {
                    "question": question,
                    "answer": answer,
                    "text": combined,
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                total += 1
                TOPIC_COUNT[topic] += 1
                pbar.update(1)

                # Show a progress sample every 5000 accepted records. (Checking
                # inside the accept branch keeps this from re-printing while
                # rejected samples leave `total` stuck at a multiple of 5000.)
                if total % 5000 == 0:
                    print(f"\n[Sample {total}]")
                    print("Q:", question)
                    print("A:", answer)
                    print("Tokens:", token_count(combined))

        pbar.close()
        print(f"\n✅ Saved {total} samples to {SAVE_PATH}")


if __name__ == "__main__":
    main()
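
# A quick way to inspect the output (a sketch; assumes the Hugging Face
# `datasets` package, which this script does not otherwise use):
#
#     from datasets import load_dataset
#     ds = load_dataset("json", data_files=SAVE_PATH, split="train")
#     print(ds[0]["text"])  # e.g. "Q: What is rice? A: rice is a grain eaten daily in Asia."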