# MiniGPT / datasetgen-synthetic.py
# CreatedNull's picture
# Upload folder using huggingface_hub
# 79eec1d verified
import json
import os
import random
# Template table: topic name -> list of (prompt, response) pairs.
# Only the "Math Reasoning" templates carry str.format placeholders
# ({a}, {b}, {sum}); every other topic stores literal text.
topics = {
    "Math Reasoning": [
        ("What is {a} + {b}?", "{a} + {b} is {sum}."),
        ("If you have {a} apples and get {b} more, how many?", "{a} + {b} = {sum} apples."),
        ("Solve: {a} + {b}", "The answer is {sum}."),
    ],
    "Causality": [
        ("If it rains, what might happen?", "If it rains, the ground may become wet."),
        (
            "Why do plants grow towards light?",
            "Because light is a stimulus and plants respond by growing towards it.",
        ),
        ("What happens if you drop a glass?", "It will likely break due to gravity."),
    ],
    "Grammar Correction": [
        ("Correct this: 'He go to school everyday.'", "'He goes to school every day.'"),
        ("Fix this sentence: 'I has two cat.'", "'I have two cats.'"),
        ("Can you fix this sentence: 'he have two taco.'", "'He has two tacos.'"),
        ("What’s the correct form of: 'She don't like it.'", "'She doesn't like it.'"),
    ],
    "Common Sense": [
        ("Can a person eat soup with a fork?", "No, it is impractical to eat soup with a fork."),
        ("Should you touch fire?", "No, touching fire can cause burns."),
        ("If you're tired, what should you do?", "You should rest or sleep."),
    ],
    "World Knowledge": [
        ("What is the capital of France?", "Paris is the capital of France."),
        ("Who was the first president of the USA?", "George Washington."),
        ("What currency is used in Japan?", "The Japanese Yen."),
    ],
    "Instruction Following": [
        ("Open the window and turn off the light.", "Opening the window. Turning off the light."),
        ("Sort these numbers in ascending order: 5, 2, 8.", "2, 5, 8."),
        ("Sort these numbers in descending order: 5, 2, 8.", "8, 5, 2."),
        (
            "Describe how to make a sandwich.",
            "Take two slices of bread, add your fillings, and place one slice on top.",
        ),
    ],
}
def generate_sample(id, topic):
    """Build one prompt/response record for ``topic``.

    A template pair is picked at random from ``topics[topic]``. For the
    "Math Reasoning" topic the pair is filled in with two random integers
    in the inclusive range 1..20 and their sum; templates of every other
    topic are used verbatim.

    Args:
        id: Identifier stored unchanged under the ``"id"`` key.
        topic: Key into the module-level ``topics`` table.

    Returns:
        A dict with the keys ``"id"``, ``"topic"``, ``"input"`` and
        ``"output"``, in that order.

    Raises:
        KeyError: If ``topic`` is not a key of ``topics``.
    """
    template = random.choice(topics[topic])
    question, answer = template[0], template[1]

    if topic == "Math Reasoning":
        # Operands are drawn after the template (a first, then b) so the
        # sequence of RNG calls stays the same for seeded runs.
        a = random.randint(1, 20)
        b = random.randint(1, 20)
        values = {"a": a, "b": b, "sum": a + b}
        question = question.format(**values)
        answer = answer.format(**values)

    return dict(id=id, topic=topic, input=question, output=answer)
def generate_dataset(n=10000):
    """Generate ``n`` records, each on a randomly chosen topic.

    Record ids are the consecutive integers ``0 .. n - 1``; the topic of
    each record is picked independently from the keys of ``topics``.

    Args:
        n: Number of records to produce.

    Returns:
        A list of ``n`` dicts as returned by ``generate_sample``.
    """
    available_topics = list(topics)
    # The topic is chosen before generate_sample runs, so each record
    # consumes random numbers in the order: topic, template, operands.
    return [
        generate_sample(index, random.choice(available_topics))
        for index in range(n)
    ]
def save_as_jsonl(data, path="./data/reasoned_data.jsonl"):
    """Write ``data`` to ``path`` as JSON Lines (one JSON document per line).

    The parent directory of ``path`` is created when it does not exist, so
    the default location works on a fresh checkout instead of failing with
    ``FileNotFoundError``. An existing file at ``path`` is overwritten.

    Args:
        data: Iterable of JSON-serializable items.
        path: Destination file; written as UTF-8 with non-ASCII characters
            kept as-is rather than escaped.

    Raises:
        TypeError: If an item cannot be serialized by ``json``.
        OSError: If the directory or file cannot be created or written.
    """
    parent = os.path.dirname(path)
    if parent:
        # A bare filename has no directory part; makedirs("") would raise.
        os.makedirs(parent, exist_ok=True)

    with open(path, "w", encoding="utf-8") as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
if __name__ == "__main__":
    # Script entry point: build 10,000 records and write them to the
    # default location used by save_as_jsonl.
    save_as_jsonl(generate_dataset(10000))
    print("Saved to ./data/reasoned_data.jsonl")