# Spaces:
# Runtime error
# Runtime error
import csv
import json
import random
from itertools import islice

import numpy as np
from sklearn.model_selection import ShuffleSplit
# Build a prefix -> full-prompt dataset from ./data/prompts.csv and write it
# to ./data/dataset_openprompt.json.
#
# samples["x"][i] is a short prefix of a row's modifier list and
# samples["y"][i] is that row's full comma-joined modifier list.
samples = {
    "x": [],
    "y": [],
}

# Set to True to read only a small slice of the CSV.
little = False

# Maximum number of CSV rows to read. This caps rows, not emitted samples:
# rows with fewer than three modifiers are read but skipped below.
# NOTE(review): the original comment here said "200,000 entries", which does
# not match the value 500000 -- confirm which one is intended.
all_loaded_sample = 500000

# Row cap for this run; "little" mode reads at most 100 rows.
row_limit = min(100, all_loaded_sample) if little else all_loaded_sample

# newline="" is what the csv module asks for so that line breaks embedded in
# quoted fields are handed to the parser untouched.
with open("./data/prompts.csv", newline="") as f:
    csv_reader = csv.DictReader(f)
    # islice stops after exactly row_limit rows; the previous
    # `row_number > limit` checks read one row more than the limit.
    for row in islice(csv_reader, row_limit):
        modifiers = json.loads(row["raw_data"])["modifiers"]
        # The longest prefix used below has three modifiers, so shorter
        # lists are skipped.
        if len(modifiers) < 3:
            continue
        label = ",".join(modifiers)

        # The input is a short text and the target is the long text, so x is
        # a prefix of the label. Prefix lengths 1, 2 and 3 are drawn in a
        # 6:3:1 ratio. randint is inclusive on both ends, so the range must
        # be 1..10; the previous 1..11 produced 6:3:2.
        n = random.randint(1, 10)
        if n <= 6:
            x = modifiers[0]
        elif n <= 9:
            x = ",".join(modifiers[:2])
        else:
            x = ",".join(modifiers[:3])

        samples["x"].append(x)
        samples["y"].append(label)

# ensure_ascii=False writes non-ASCII characters verbatim, so the file
# encoding is pinned to UTF-8 instead of depending on the platform default.
with open("./data/dataset_openprompt.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, indent=4, ensure_ascii=False)
print("*"*40, "save train done.", "with little" if little else "", "*"*40)