LLM / dataset.py
dushuai112233's picture
Upload 5 files
47f89ab verified
import json
import pandas as pd
data_path = [
"./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv",
"./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv",
"./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv",
"./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
]
train_json_path = "./data/train.json"
val_json_path = "./data/val.json"
# 每个数据取 10000 条作为训练
train_size = 10000
# 每个数据取 2000 条作为验证
val_size = 2000
def main():
train_f = open(train_json_path, "a", encoding='utf-8')
val_f = open(val_json_path, "a", encoding='utf-8')
for path in data_path:
data = pd.read_csv(path, encoding='ANSI')
train_count = 0
val_count = 0
for index, row in data.iterrows():
question = row["ask"]
answer = row["answer"]
line = {
"question": question,
"answer": answer
}
line = json.dumps(line, ensure_ascii=False)
if train_count < train_size:
train_f.write(line + "\n")
train_count = train_count + 1
elif val_count < val_size:
val_f.write(line + "\n")
val_count = val_count + 1
else:
break
print("数据处理完毕!")
train_f.close()
val_f.close()
if __name__ == '__main__':
main()