import ujson
import re
from os.path import dirname, abspath, exists, isdir
from os import remove, mkdir, walk
import time
from collections import defaultdict
from matplotlib import pyplot as plt
import codecs, csv
import pandas as pd
import numpy as np
from rich import progress
from rich.table import Table
from rich.console import Console
from fastparquet import ParquetFile, write
import pyarrow.parquet as pq
from opencc import OpenCC

import sys
sys.path.extend(['.','..'])

from logger import Logger
from config import PROJECT_ROOT
from utils.functions import get_path_of_suffix_files, DropDatasetDuplicate

log = Logger('data_process', save2file=True, file_name=PROJECT_ROOT + '/logs/raw_data_process.log')

# characters treated as punctuation when deduplicating (note: a literal space is used here,
# not the regex token '\s', which inside a plain string would add a stray 's' to the set)
punctuation = set("!\"#$%&'()*+,-./:;<=>?@[]^_`{|}~.,;《》?!“”‘’@#¥%…&×()——+【】{};;●,。&~、| ::\n")
en_punctuation = ",().!;:"
zh_punctuation = ",()。!;:"

def delete_file(file: str) -> bool:
    '''
    Ask for confirmation before deleting a file.
    '''
    if exists(file):
        ans = input('delete file: {} ? Yes (y) or No (n)'.format(file))
        ans = ans.lower()
        if ans in ('yes', 'y'):
            remove(file)
            print('deleted.')
            return True
    return False

def remove_duplicate_punctuation(sentence: str) -> str:
    '''
    Remove duplicated punctuation and duplicated spaces from a sentence.
    '''
    # replace half-width / full-width spaces with a comma; any duplicated punctuation
    # this produces is removed by the deduplication loop below
    sentence = re.sub(' | ', ',', sentence)

    ans = ''
    n = len(sentence)
    p = 0
    while p < n:
        ans += sentence[p]
        while p + 1 < n and sentence[p] in punctuation and sentence[p + 1] in punctuation:
            p += 1
        p += 1
    return ans

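# Illustration (not executed by the pipeline): a run of consecutive punctuation
# collapses to its first symbol, e.g.
# >>> remove_duplicate_punctuation('你好,,!!世界')
# '你好,世界'
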
def convert_en_punctuation_to_zh_punct(sentence: str) -> str:
    '''
    Replace English punctuation in a sentence with the corresponding Chinese punctuation.
    '''
    n = len(zh_punctuation)
    for i in range(n):
        sentence = sentence.replace(en_punctuation[i], zh_punctuation[i])
    return sentence

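# Illustration (not executed by the pipeline): each character of en_punctuation maps to
# the character at the same index of zh_punctuation, e.g.
# >>> convert_en_punctuation_to_zh_punct('Hello,world.')
# 'Hello,world。'
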
def get_sentences_dice_similarity(st_a: str, st_b: str) -> float:
    '''
    Dice similarity of two sentences, computed over their character sets:
        s(a, b) = 2 * len( set(a) & set(b) ) / (len(set(a)) + len(set(b)))
    '''
    set_a, set_b = set(st_a), set(st_b)
    total_len = len(set_a) + len(set_b)

    if total_len == 0: return 0.0

    inter_set = set_a & set_b

    return (2 * len(inter_set)) / total_len

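# Illustration (not executed by the pipeline): with character sets {'a','b','c','d'} and
# {'c','d','e','f'} the intersection has 2 elements, so
# >>> get_sentences_dice_similarity('abcd', 'cdef')
# 0.5
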
def write_single_parquet_file(file_name: str, data_frame: pd.DataFrame) -> None:
    '''
    Write a dataframe to a single parquet file, appending if the file already exists.
    '''
    append = False
    if exists(file_name):
        append = True

    write(file_name, data_frame, compression='GZIP', append=append)

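# Note (assumption about fastparquet behaviour): each call with append=True adds the
# dataframe as another row group of the existing file, so every chunk written to the
# same file is expected to share the same column schema.
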
def read_and_write_template(read_file: str, write_to_file: str, call_back: object, group_cnt: int=10000) -> None:
    '''
    Template for reading raw data and writing the processed result.
    A callback function call_back must be provided:
        read_file: raw data file
        write_to_file: file the processed data is saved to
        call_back: a function that takes one string and returns a processed dict;
                   it must return None if the input string is invalid data
        group_cnt: number of rows per parquet write chunk
    e.g.:
    >>> def call_back(inputs: str) -> dict:
    >>>     if check(inputs) not valid:
    >>>         return None
    ...
    ...     do something for inputs
    ...
    >>>     my_dict = {
    >>>         'prompt': inputs['p'],
    >>>         'response': inputs['a1'] + inputs['a2'],
    >>>         ...
    >>>     }
    >>>     return my_dict
    '''
    log.info('process file:{}'.format(read_file), save_to_file=True)
    start = time.time()

    raw_line_cnt = 0
    keep_line_cnt = 0

    with progress.open(read_file, 'r', encoding='utf-8') as f_read:
        cur_rows = []
        append = cur_rows.append
        for line in f_read:
            try:
                raw_line_cnt += 1

                write_dict = call_back(line)

                if write_dict is None: continue

                keep_line_cnt += 1
                append(write_dict)

                # ujson.dump(write_obj, f_write, indent=4, ensure_ascii=False)
                # ujson.dump(write_obj, f_write, ensure_ascii=False,)
                # f_write.write('\n')

                if len(cur_rows) >= group_cnt:
                    df = pd.DataFrame(cur_rows)
                    write_single_parquet_file(write_to_file, df)
                    cur_rows = []
                    append = cur_rows.append

            except Exception as e:
                # log.error('处理文件异常:{}, content:{}'.format(str(e), line))
                print(line)
                raise e
        # end for

        # handle the remaining rows
        if len(cur_rows) > 0:
            df = pd.DataFrame(cur_rows)
            write_single_parquet_file(write_to_file, df)
            cur_rows = []

    end = time.time()

    log.info('raw file: {}, {} lines in total, {} lines kept after processing, saved to: {}. time cost: {:.6}s'\
        .format(read_file, raw_line_cnt, keep_line_cnt, write_to_file, end - start), save_to_file=True)

# =====================================  dataset processing  =================================

def process_web_text(keep_start: int=5, response_less_word: int=10) -> None:
    '''
    Process the 4.25M-sample community QA dataset webtext2019zh (knowledge domain).
    keep_start: only keep QA pairs with at least keep_start upvotes (stars)
    response_less_word: the answer must contain at least response_less_word characters
    '''
    file_names = [
        '/data/raw_data/web_text_zh_test.json',
        '/data/raw_data/web_text_zh_train.json',
        '/data/raw_data/web_text_zh_valid.json',
    ]

    save_file_name = PROJECT_ROOT + '/data/my_data/my_web_text_zh.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(save_file_name):
        assert delete_file(save_file_name)

    def process_function(line: str) -> dict:
        item = ujson.loads(line)

        if item['star'] < keep_start or len(item['content']) < response_less_word:
            return None

        # data cleaning: remove duplicated punctuation
        prompt = remove_duplicate_punctuation(item['title'])
        response = remove_duplicate_punctuation(item['content'])
        write_dict = {
            "prompt": prompt,
            "response": response,
        }
        return write_dict

    for file_name in file_names:
        read_file = PROJECT_ROOT + file_name

        read_and_write_template(read_file, save_file_name, process_function)

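# Input format sketch (hypothetical sample showing only the fields read above; the real
# webtext2019zh lines contain additional fields):
# {"title": "番茄炒蛋怎么做?", "content": "先把鸡蛋打散……", "star": 12}
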
def process_bake_qa(response_less_word: int=15) -> None:
    '''
    Process the 1.47M-sample baike_qa (Baidu Zhidao) knowledge QA dataset.
    '''
    file_names = [
        '/data/raw_data/baike_qa_train.json',
        '/data/raw_data/baike_qa_valid.json',
    ]

    save_file_name = PROJECT_ROOT + '/data/my_data/my_baike_qa.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(save_file_name):
        assert delete_file(save_file_name)

    def process_function(line: str) -> dict:
        item = ujson.loads(line)

        if len(item['answer']) < response_less_word:
            return None

        # data cleaning
        prompt = ''
        if get_sentences_dice_similarity(item['title'], item['desc']) >= 0.90:
            # title and desc are too similar, use only the title as the question
            prompt = item['title']
        else:
            # concatenate title and desc to form the question
            prompt = "{}{}".format(item['title'], item['desc'])

        # remove \r
        prompt = prompt.replace('\r','')
        # remove duplicated punctuation
        prompt = remove_duplicate_punctuation(prompt)

        # remove \r and duplicated punctuation from the answer
        response = item['answer'].replace('\r','')
        response = remove_duplicate_punctuation(response)

        # drop samples whose question or answer is too short
        if len(prompt) < 3 or len(response) < response_less_word:
            return None

        write_dict = {
            "prompt": prompt,
            "response": response,
        }
        return write_dict

    for file_name in file_names:
        read_file = PROJECT_ROOT + file_name

        read_and_write_template(read_file, save_file_name, process_function)

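# Input format sketch (hypothetical sample; only the fields read above are shown, the
# real baike_qa lines contain more fields):
# {"title": "感冒了吃什么好?", "desc": "最近总是打喷嚏流鼻涕", "answer": "多喝热水,注意休息……"}
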
def repair_line_error_csv_file(raw_csv_file: str, save_suffix: str, read_encoding: str='utf-8') -> None:
    '''
    Repair a csv file: replace line breaks inside fields with the literal string '\n'
    and replace English commas inside fields with Chinese commas.
    '''
    with codecs.open(raw_csv_file, 'r', encoding=read_encoding, errors='ignore') as f:
        reader = csv.reader(f)
        new_lines = []

        for line in reader:
            for i in range(len(line)):
                line[i] = line[i].replace('\n', '\\n')  # handle abnormal line breaks inside a field
                line[i] = line[i].replace(',', ',')      # English comma -> Chinese comma

            new_lines.append(line)

    with open(raw_csv_file[: -4] + save_suffix, 'w', encoding='utf-8', newline="") as f:
        writer = csv.writer(f)
        writer.writerows(new_lines)

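# Illustration (not executed): a field value like 'hello,\nworld' read from the raw csv
# becomes 'hello,\\nworld' in the repaired file, which is written next to the source
# file using save_suffix (e.g. foo.csv -> foo.repaired.csv).
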
def process_chinese_medical_datasets(response_less_word: int=15) -> None:
    '''
    Process the Chinese medical-domain QA dataset.
    '''
    raw_dataset_dir = PROJECT_ROOT + '/data/raw_data/chinese_medical_dialogue_datasets'

    raw_data_files = get_path_of_suffix_files(raw_dataset_dir, '.csv')

    # if the repaired files do not exist yet, fix the abnormal line breaks in the csv files
    suffix = '.repaired.csv'
    need_to_repair_files = [
        file_name for file_name in raw_data_files \
            if not file_name.endswith(suffix) and file_name[0: -4] + suffix not in raw_data_files
    ]

    # repair the files with abnormal line breaks
    for file_name in need_to_repair_files:
        repair_line_error_csv_file(file_name, suffix, read_encoding='gb2312')

    # fetch the source files again (i.e. the repaired ones)
    raw_data_files = get_path_of_suffix_files(raw_dataset_dir, suffix)

    # file to save the processed data to
    save_file = PROJECT_ROOT + '/data/my_data/my_chinese_medical_dialogue.parquet'
    # for file_name in raw_data_files:
    #     file_name = file_name.split('/')[-1][0: -(len(suffix))] + '.parquet'
    #     file_name = PROJECT_ROOT + '/data/my_data/' + file_name
    #     save_files.append(file_name)

    # later writes append, so delete the file first if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    def process_function(line: str) -> dict:
        # department,title,ask,answer
        item = line.split(',')  # csv fields are comma-separated

        if len(item) < 4:
            print(item)
            return None

        if len(item[3]) < response_less_word:
            return None

        # data cleaning
        prompt = ''
        if get_sentences_dice_similarity(item[1], item[2]) >= 0.90:
            # title and ask are too similar, use only ask as the question
            prompt = item[2]
        else:
            # concatenate title and ask to form the question
            prompt = "{}{}".format(item[1], item[2])

        # remove \r
        prompt = prompt.replace('\r','')
        # remove duplicated punctuation
        prompt = remove_duplicate_punctuation(prompt)

        # remove \r and duplicated punctuation from the answer
        response = ''.join(item[3: ]).replace('\r','')
        response = remove_duplicate_punctuation(response)

        # drop samples whose question or answer is too short
        if len(prompt) < 3 or len(response) < response_less_word:
            return None

        write_dict = {
            "prompt": prompt,
            "response": response,
        }

        return write_dict

    for i, file_name in enumerate(raw_data_files):
        read_file = file_name

        read_and_write_template(read_file, save_file, process_function)

def process_finace_dataset(prompt_less_word: int=10, response_less_word: int=15) -> None:
    '''
    Process the finance QA dataset.
    '''
    finace_data_file = PROJECT_ROOT + '/data/raw_data/financezhidao_filter.csv'

    suffix = '.repaired.csv'
    if not exists(finace_data_file[0: -4] + suffix):
        repair_line_error_csv_file(finace_data_file, save_suffix=suffix, read_encoding='utf-8')

    def process_function(line: str) -> dict:
        # title,prompt,reply,is_best
        item = line.split(',')  # csv fields are comma-separated

        if len(item) < 4:
            print(item)
            return None

        if len(item[0]) + len(item[1]) < prompt_less_word or len(item[2]) < response_less_word:
            return None

        # data cleaning
        prompt = ''
        if get_sentences_dice_similarity(item[0], item[1]) >= 0.90:
            # title and prompt are too similar, use the longer of the two as the question
            prompt = item[0] if len(item[0]) > len(item[1]) else item[1]
        else:
            # concatenate title and prompt to form the question
            prompt = "{}{}".format(item[0], item[1])

        # remove \r
        prompt = prompt.replace('\r','')
        # remove duplicated punctuation
        prompt = remove_duplicate_punctuation(prompt)

        # remove \r and duplicated punctuation from the reply
        response = ''.join(item[2]).replace('\r','')
        response = remove_duplicate_punctuation(response)

        # drop samples whose question or answer is too short
        if len(prompt) < prompt_less_word or len(response) < response_less_word:
            return None

        write_obj = {
            "prompt": prompt,
            "response": response,
        }

        return write_obj

    read_file = finace_data_file[0: -4] + suffix
    write_file = PROJECT_ROOT + '/data/my_data/' + read_file.split('/')[-1][0: -(len(suffix))] + '.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(write_file):
        assert delete_file(write_file)

    read_and_write_template(read_file, write_file, process_function)

def process_zhihu_kol_dataset(prompt_less_word: int=4, response_less_word: int=10, group_cnt: int=10000) -> None:
    '''
    Process the zhihu-kol (Zhihu) dataset.
    '''
    raw_zhihu_data_path = abspath(dirname(dirname(__file__))) + '/data/raw_data/zhihu-kol'
    file_names = []
    suffix = '.parquet'
    for root, _, files in walk(raw_zhihu_data_path):
        for file in files:
            if file.endswith(suffix):
                file_names.append(root + '/' + file)

    def process_function(sentence: str) -> str:
        '''
        Clean a single sentence.
        '''
        # remove \r
        sentence = sentence.replace('\r','')
        # remove duplicated punctuation
        sentence = remove_duplicate_punctuation(sentence)
        return sentence

    # row keys :['INSTRUCTION', 'RESPONSE', 'SOURCE', 'METADATA']
    save_file = PROJECT_ROOT + '/data/my_data/zhihu_kol.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    all_cnt, keep_cnt = 0, 0
    cur_rows = []
    append = cur_rows.append
    for file in file_names:
        pf = pq.read_table(file)
        log.info('process file: {}'.format(file), save_to_file=True)

        for prompt, response in progress.track(zip(pf['INSTRUCTION'], pf['RESPONSE']), total=pf.num_rows):
            all_cnt += 1
            prompt, response = prompt.as_py(), response.as_py()

            prompt = process_function(prompt)
            response = process_function(response)

            if len(prompt) < prompt_less_word or len(response) < response_less_word:
                continue

            keep_cnt += 1
            write_dict = {
                'prompt': prompt,
                'response': response,
            }
            append(write_dict)

            if len(cur_rows) >= group_cnt:
                df = pd.DataFrame(cur_rows)
                write_single_parquet_file(save_file, df)
                cur_rows = []
                append = cur_rows.append
    # end for

    if len(cur_rows) > 0:
        df = pd.DataFrame(cur_rows)
        write_single_parquet_file(save_file, df)
        cur_rows = []

    log.info('save file to: {}, {} rows in total, {} rows kept after cleaning'.format(save_file, all_cnt, keep_cnt), save_to_file=True)

def process_belle_knowledge_enhanced_dataset(response_less_words: int=15, group_cnt: int=10000) -> None:
    '''
    Process the open-source BELLE knowledge-enhanced dataset.
    '''
    file_names = [
        '/data/raw_data/bell_open_source/train_2M_CN.json',
        '/data/raw_data/bell_open_source/train_0.8M_CN.json',
        '/data/raw_data/bell_open_source/Belle_open_source_1M.json',
    ]

    save_file = PROJECT_ROOT + '/data/my_data/my_belll_3M_cn.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    def process_function(line: str) -> dict:
        '''
        Per-line processing function.
        '''
        item = ujson.loads(line)
        prompt = item['instruction']
        response = item['output']

        # drop translation tasks
        if '翻译' in prompt or 'translate' in prompt.lower():
            return None

        # drop table-related tasks
        if '表格' in prompt or '-----' in prompt or '-----' in response:
            return None

        if len(response) < response_less_words:
            return None

        prompt = remove_duplicate_punctuation(prompt)
        response = remove_duplicate_punctuation(response)

        if len(response) < response_less_words:
            return None

        write_dict = {
            'prompt': prompt,
            'response': response
        }

        return write_dict

    for file in file_names:
        file = PROJECT_ROOT + file

        read_and_write_template(file, save_file, process_function)

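# Input format sketch (hypothetical sample; each BELLE line is one JSON object, and only
# the 'instruction' and 'output' fields are used above):
# {"instruction": "写一首关于春天的诗", "input": "", "output": "春风拂面柳色新……"}
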
def convert_wiki_to_simple_zh(buffer_size: int=10000) -> None:
    '''
    Convert the traditional-Chinese wiki corpus to simplified Chinese.
    '''
    raw_zh_wiki_file = PROJECT_ROOT + '/data/raw_data/wiki.txt'
    save_zh_wiki_simple_file = PROJECT_ROOT + '/data/raw_data/wiki.simple.txt'

    if exists(save_zh_wiki_simple_file):
        assert delete_file(save_zh_wiki_simple_file)

    cc = OpenCC('t2s')
    cur_rows = []
    append = cur_rows.append

    def procees_line(line: str) -> str:
        '''
        Process one line of text.
        '''
        # convert traditional Chinese to simplified Chinese
        line = cc.convert(line)

        line = re.sub(r"\「|\」|\「|\」|\『|\』", '\"', line)  # replace 「」『』 brackets with double quotes
        line = re.sub(r"\,\)|\;\)", ')', line)  # e.g. 罗德·法尼(Rod Dodji Fanni,)
        line = re.sub(r"\(\,|\(\,", '(', line)  # e.g. 阿魯拉·基馬(Alula Girma (,

        line = convert_en_punctuation_to_zh_punct(line)  # English punctuation -> Chinese punctuation
        line = remove_duplicate_punctuation(line)  # remove empty Chinese brackets and duplicated punctuation

        return line

    with progress.open(raw_zh_wiki_file, 'r', encoding='utf-8') as read_f:
        with open(save_zh_wiki_simple_file, 'a', encoding='utf-8') as write_f:
            for line in read_f:
                line = procees_line(line)
                if len(line.strip()) == 0: continue

                line = '{}\n'.format(line)
                append(line)

                if len(cur_rows) >= buffer_size:
                    write_f.writelines(cur_rows)
                    cur_rows = []
                    append = cur_rows.append

            if len(cur_rows) > 0:
                write_f.writelines(cur_rows)
                cur_rows = []

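# Illustration (not executed): OpenCC('t2s') maps traditional characters to their
# simplified forms, e.g.
# >>> OpenCC('t2s').convert('漢語維基百科')
# '汉语维基百科'
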
def process_zh_wiki_data_to_datset(groups_cnt: int=10000, max_len: int=512, seed: int=23333) -> None:
    '''
    Convert the Chinese wiki data into a QA dataset.
    wiki dump download: https://dumps.wikimedia.org/zhwiki/
    to convert the downloaded bz2 file to wiki.txt, see: https://github.com/apertium/WikiExtractor
    '''
    raw_zh_wiki_file = PROJECT_ROOT + '/data/raw_data/wiki.txt'
    zhwiki_simple_file = PROJECT_ROOT + '/data/my_data/wiki_zh_simple.parquet'

    # delete existing output
    if exists(zhwiki_simple_file):
        assert delete_file(zhwiki_simple_file)

    # convert traditional Chinese to simplified Chinese
    cc = OpenCC('t2s')
    all_cnt, keep_cnt = 0, 0

    # templates used to turn an entry title into a question
    prompt_prefix = [
        '什么是{}?',
        '介绍一下{}',
        '介绍一下什么是{}',
        '写一篇关于{}的介绍',
        '{}是什么?',
        '你知道{}吗?',
        '生成关于{}的介绍',
        '我想知道关于{}的详细信息',
        '你了解{}吗?',
        '请解释一下{}',
        '对于{},你有什么了解或看法吗?',
        '请告诉我关于{}的信息',
        '请简要描述一下{}',
        '请提供有关{}的一些详细信息',
        '能否解释一下{}是什么?',
        '请分享一些关于{}的背景知识',
        '请简要概括一下{}',
        '能给我一些关于{}的背景资料吗?',
        '有关{}的信息可以分享一下吗?',
        '你能告诉我{}是什么吗?',
    ]

    def procees_line(line: str) -> str:
        '''
        Process one line of text.
        '''
        # convert traditional Chinese to simplified Chinese
        line = cc.convert(line)

        line = re.sub(r"\「|\」|\「|\」|\『|\』", '\"', line)  # replace 「」『』 brackets with double quotes
        line = re.sub(r"\,\)|\;\)", ')', line)  # e.g. 罗德·法尼(Rod Dodji Fanni,)
        line = re.sub(r"\(\,|\(\,", '(', line)  # e.g. 阿魯拉·基馬(Alula Girma (,

        line = convert_en_punctuation_to_zh_punct(line)  # English punctuation -> Chinese punctuation
        line = remove_duplicate_punctuation(line)  # remove empty Chinese brackets and duplicated punctuation

        return line

    np.random.seed(seed)
    choice = np.random.choice

    with progress.open(raw_zh_wiki_file, 'r', encoding='utf-8') as read_file:
        prompt = ''
        response = ''
        pre_line_len = 0
        cur_rows = []
        append = cur_rows.append
        for line in read_file:
            all_cnt += 1

            # the prompt has already been saved, but extra lines of the same entry remain
            # that would push the response beyond max_len, so skip them without processing
            if len(prompt) == 0 and pre_line_len > 0:
                pre_line_len = len(line.strip())
                continue

            # clean one line
            line = procees_line(line)

            # determine the question: if pre_line_len is 0 (the previous line was blank),
            # the current line starts a new encyclopedia entry, so use it as the prompt
            if prompt == '' and line.endswith(':') and pre_line_len == 0:
                prompt = choice(prompt_prefix).format(line[0: -1])
                continue

            pre_line_len = len(line.strip())

            # the lines following the question form the answer
            if prompt != '' and not line.endswith(':'):
                # pre_line_len already equals len(line.strip()); if it is 0 the current
                # entry has ended and the sample must be saved regardless of its length
                if len(response) + len(line) <= max_len and pre_line_len != 0:
                    response = '{}{}'.format(response, line)
                elif len(response) + len(line) > max_len or pre_line_len == 0:
                    # the length limit was exceeded or the current entry has ended: save one sample
                    keep_cnt += 1
                    response = '{}{}'.format(response, line)
                    append({'prompt': prompt, 'response': ''.join(response[0: max_len])})

                    prompt = ''
                    response = ''

            # save to file every groups_cnt rows
            if len(cur_rows) >= groups_cnt:
                df = pd.DataFrame(cur_rows)
                write_single_parquet_file(zhwiki_simple_file, df)
                cur_rows = []
                append = cur_rows.append
        # end for

        if len(prompt) > 0 and len(response) > 0:
            keep_cnt += 1
            append({'prompt': prompt, 'response': response})

        if len(cur_rows) > 0:
            df = pd.DataFrame(cur_rows)
            write_single_parquet_file(zhwiki_simple_file, df)
            cur_rows = []

    log.info("merge into file: {}, {} rows in total, {} rows kept after cleaning".format(zhwiki_simple_file, all_cnt, keep_cnt), save_to_file=True)

def merge_dataset_as_single_file(groups_cnt: int=50000, max_len: int=512, min_len: int=3, cut_max_len: bool=False) -> None:
    '''
    Merge multiple datasets into a single dataset file.
    '''
    from_parquet_files = get_path_of_suffix_files(PROJECT_ROOT + '/data/my_data', '.parquet')

    save_file = PROJECT_ROOT + '/data/my_dataset.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    cur_rows = []
    append = cur_rows.append

    all_cnt, keep_cnt = 0, 0
    for file in from_parquet_files:
        print('process file: {}'.format(file))

        parquet_table = pq.read_table(file)

        for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):

            prompt, response = prompt.as_py(), response.as_py()
            all_cnt += 1

            if len(prompt) < min_len or len(response) < min_len:
                continue

            if cut_max_len and (len(prompt) > max_len or len(response) > max_len):
                prompt = prompt[0: max_len]
                response = response[0: max_len]

            keep_cnt += 1
            append({'prompt': prompt, 'response': response})

            if len(cur_rows) >= groups_cnt:
                df = pd.DataFrame(cur_rows)
                write_single_parquet_file(save_file, df)
                cur_rows = []
                append = cur_rows.append

    # handle the remaining rows
    if len(cur_rows) > 0:
        df = pd.DataFrame(cur_rows)
        write_single_parquet_file(save_file, df)
        cur_rows = []

    log.info("merge into file: {}, {} rows in total, {} rows kept after cleaning".format(save_file, all_cnt, keep_cnt), save_to_file=True)

def remove_dataset_duplicate_rows(groups_cnt: int=50000) -> None:
    '''
    Remove duplicated rows from the dataset using MinHash.
    '''
    from_parquet_files = PROJECT_ROOT + '/data/my_dataset.parquet'

    save_file = PROJECT_ROOT + '/data/my_dataset_no_dulpticates.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    cur_rows = []
    all_cnt, keep_cnt = 0, 0
    row_index = -1
    drop_dataset_duplicate = DropDatasetDuplicate(threshold=0.85, num_perm=256)

    parquet_table = pq.read_table(from_parquet_files)
    all_cnt = parquet_table.num_rows

    # first pass: find which rows are duplicates
    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):
        row_index += 1

        doc = f"{prompt.as_py()}{response.as_py()}"
        drop_dataset_duplicate.add_doc(index=row_index, doc=doc)

    row_index = -1
    need_to_drop_indexs = drop_dataset_duplicate.get_duplicate_indexs()

    # second pass: skip the duplicated rows when building the new dataset
    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):
        row_index += 1  # row_index must be incremented whether or not the row is kept

        # skip duplicated rows
        if row_index in need_to_drop_indexs:
            continue

        cur_rows.append({'prompt': prompt.as_py(), 'response': response.as_py()})
        keep_cnt += 1

        if len(cur_rows) >= groups_cnt:
            df = pd.DataFrame(cur_rows)
            write_single_parquet_file(save_file, df)
            cur_rows = []

    # handle the remaining rows
    if len(cur_rows) > 0:
        df = pd.DataFrame(cur_rows)
        write_single_parquet_file(save_file, df)

    log.info("merge into file: {}, {} rows in total, {} rows kept after deduplication".format(save_file, all_cnt, keep_cnt), save_to_file=True)

def shuffle_parquet_dataset(parquet_file: str, shuffle_file: str, seed: int=23333, groups_cnt: int=65536) -> None:
    '''
    Shuffle a parquet dataset.
    '''
    if not exists(parquet_file):
        raise Exception('can not find parquet file: {}'.format(parquet_file))

    print('start shuffle...')
    pf = pq.read_table(parquet_file)
    df = pf.to_pandas()
    df = df.sample(frac=1.0, replace=False, random_state=seed, axis=0)

    if exists(shuffle_file):
        assert delete_file(shuffle_file)

    # write the parquet file in chunks, otherwise reading it back on low-memory machines will OOM
    n = len(df)
    for i in range(0, n, groups_cnt):
        cur_group_df = df[i: i + groups_cnt]
        write_single_parquet_file(shuffle_file, cur_group_df)

def count_my_json_data() -> None:
    '''
    Count the number of rows in all current json datasets.
    '''
    my_data_files = get_path_of_suffix_files(PROJECT_ROOT + '/data/my_data', '.json')
    result = [['file_name', 'count']]
    all_cnt = 0
    for file in my_data_files:
        file_name = file.split('/')[-1]
        cur_cnt = 0
        with progress.open(file, 'r', encoding='utf-8') as f:
            for _ in f:
                cur_cnt += 1

        all_cnt += cur_cnt
        result.append([file_name, cur_cnt])

    result.append(['total', all_cnt])

    log.info(str(result), save_to_file=True)

    console = Console()
    table = Table(show_header=True, show_lines=True,)

    for col in result[0]:
        table.add_column(col)
    for i in range(1, len(result)):  # skip the header row
        table.add_row(str(result[i][0]), str(result[i][1]))

    console.print(table)

def count_my_parquet_data(parquet_file: str=None) -> None:
    '''
    Count the rows of every parquet dataset under a directory, or of a single parquet file.
    '''
    my_data_files = []

    if not parquet_file:
        my_data_files = get_path_of_suffix_files(PROJECT_ROOT + '/data/my_data', '.parquet')
    elif isdir(parquet_file):
        my_data_files = get_path_of_suffix_files(parquet_file, '.parquet')
    elif parquet_file.endswith('.parquet'):
        my_data_files = [parquet_file]

    result = [['file_name', 'count']]
    all_cnt = 0
    for file in my_data_files:
        file_name = file.split('/')[-1]
        cur_cnt = 0
        pf = ParquetFile(file)

        for pf_chunk in pf:
            cur_cnt += pf_chunk.info['rows']

        all_cnt += cur_cnt
        result.append([file_name, cur_cnt])

    result.append(['total', all_cnt])

    log.info(str(result), save_to_file=True)

    console = Console()
    table = Table(show_header=True, show_lines=True,)

    for col in result[0]:
        table.add_column(col)
    for i in range(1, len(result)):  # skip the header row
        table.add_row(str(result[i][0]), str(result[i][1]))

    console.print(table)

def split_train_valid_test_datasets(source_parquet_file: str, max_len: int=320, seed: int=23333, train_ratio: float=0.91, test_ratio: float=0.0875, valid_ratio: float=0.0025, groups_cnt: int=50000) -> None:
    '''
    Split the source data into train, test, and validation sets.
    '''
    assert train_ratio + test_ratio + valid_ratio == 1.0

    train_parquet_file = PROJECT_ROOT + '/data/my_train_dataset.parquet'
    test_parquet_file = PROJECT_ROOT + '/data/my_test_dataset.parquet'
    valid_parquet_file = PROJECT_ROOT + '/data/my_valid_dataset.parquet'

    if exists(train_parquet_file): assert delete_file(train_parquet_file)
    if exists(test_parquet_file): assert delete_file(test_parquet_file)
    if exists(valid_parquet_file): assert delete_file(valid_parquet_file)

    np.random.seed(seed)

    train, test, valid = [], [], []

    parquet_table = pq.read_table(source_parquet_file)

    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):

        prompt, response = prompt.as_py(), response.as_py()
        rand = np.random.random()
        cur_data = {'prompt': ''.join(prompt[0: max_len]), 'response': ''.join(response[0: max_len])}

        if 0 <= rand < train_ratio:
            train.append(cur_data)
        elif train_ratio <= rand < train_ratio + test_ratio:
            test.append(cur_data)
        else:
            valid.append(cur_data)

        if len(train) >= groups_cnt:
            write_single_parquet_file(train_parquet_file, pd.DataFrame(train))
            train = []

        if len(test) >= groups_cnt:
            write_single_parquet_file(test_parquet_file, pd.DataFrame(test))
            test = []

        if len(valid) >= groups_cnt:
            write_single_parquet_file(valid_parquet_file, pd.DataFrame(valid))
            valid = []

    if len(train) > 0:
        write_single_parquet_file(train_parquet_file, pd.DataFrame(train))
        train = []

    if len(test) > 0:
        write_single_parquet_file(test_parquet_file, pd.DataFrame(test))
        test = []

    if len(valid) > 0:
        write_single_parquet_file(valid_parquet_file, pd.DataFrame(valid))
        valid = []

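# Note: each row is routed by a single uniform draw in [0, 1), so with the default
# ratios roughly 91% of rows land in train, 8.75% in test and 0.25% in valid; both
# prompt and response are truncated to max_len characters before being written.
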
def parquet_to_text(sep='[SEP]', buffer_size: int=50000) -> None:
    '''
    Convert the parquet dataset to a txt corpus, with sentences separated by sep.
    The txt file is used to train the tokenizer; training with huggingface's BPE
    can otherwise cause OOM.
    '''
    parquet_file = PROJECT_ROOT + '/data/my_dataset.parquet'
    txt_file = PROJECT_ROOT + '/data/my_corpus.txt'

    if exists(txt_file):
        assert delete_file(txt_file)

    source_pf = ParquetFile(parquet_file)
    cur_rows = []
    append = cur_rows.append
    with open(txt_file, 'a', encoding='utf-8') as f_write:
        for pf_chunk in progress.track(source_pf):
            for rows in pf_chunk.iter_row_groups():
                for prompt, response in zip(rows['prompt'], rows['response']):
                    append(prompt + sep + response + sep + '\n')

                    if len(cur_rows) >= buffer_size:
                        f_write.writelines(cur_rows)
                        cur_rows = []
                        append = cur_rows.append
        # end for

        if len(cur_rows) > 0:
            f_write.writelines(cur_rows)
            cur_rows = []

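# Output format sketch (from the loop above): every dataset row becomes one text line of
# the form "<prompt>[SEP]<response>[SEP]" with the default sep='[SEP]', e.g. a
# hypothetical row "什么是大熊猫?[SEP]大熊猫是一种哺乳动物。[SEP]".
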
def parquet_to_json() -> None:
    '''
    Convert a parquet file to json.
    '''
    parquet_file = PROJECT_ROOT + '/data/my_finetune_data_zh.parquet'
    json_file = PROJECT_ROOT + '/data/sft_train.json'

    if exists(json_file):
        assert delete_file(json_file)

    source_pf = ParquetFile(parquet_file)
    cur_rows = []
    append = cur_rows.append

    for pf_chunk in progress.track(source_pf):
        for rows in pf_chunk.iter_row_groups():
            for prompt, response in zip(rows['prompt'], rows['response']):
                if len(response) == 0 or len(prompt) == 0: continue
                append({
                    'prompt': str(prompt),
                    'response': str(response),
                })

    with open(json_file, 'w', encoding='utf-8') as f:
        ujson.dump(cur_rows, f, indent=4, ensure_ascii=False)

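# Output format sketch: sft_train.json holds one JSON array of objects such as
# [
#     {
#         "prompt": "...",
#         "response": "..."
#     },
#     ...
# ]
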
def dataset_length_cnt() -> None:
    '''
    Count and plot the length distribution of prompts and responses in the shuffled dataset.
    '''
    dataset_file = PROJECT_ROOT + '/data/my_dataset.shuffle.parquet'
    parquet_table = pq.read_table(dataset_file)

    que_len_dict, ans_len_dict = defaultdict(int), defaultdict(int)

    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):
        prompt, response = prompt.as_py(), response.as_py()

        que_len_dict[len(prompt)] += 1
        ans_len_dict[len(response)] += 1

    que_len, ans_len = [], []
    for k, v in que_len_dict.items():
        que_len.append([k, v])
    for k, v in ans_len_dict.items():
        ans_len.append([k, v])

    def gather_gt_x(array: list[tuple], x: int=512) -> list:
        '''
        Merge all lengths greater than x into a single bucket.
        '''
        new_array = []
        gt_x_cnt = 0
        for item in array:
            if item[0] < x:
                new_array.append([item[0], item[1]])
            else:
                gt_x_cnt += item[1]
        new_array.append([x, gt_x_cnt])

        return new_array

    max_len = 512
    ans_list = gather_gt_x(ans_len, max_len)
    ans_list.sort(key=lambda x: x[0])
    que_list = gather_gt_x(que_len, max_len)
    que_list.sort(key=lambda x: x[0])

    ans_pd = pd.DataFrame(ans_list, columns=['length', 'count'])
    que_pd = pd.DataFrame(que_list, columns=['length', 'count'])

    def plot_sub_bar(plt, x, y, title: str, color: str='g') -> None:
        plt.bar(x, y, color=color, label='sample count')
        plt.ticklabel_format(style='sci', scilimits=(0,0), axis='y')
        plt.legend()
        plt.xlabel('length')
        plt.ylabel('count')
        plt.title(title)

    plt.figure(figsize=(10, 10), dpi=200)

    plt.subplot(2, 2, 1)
    plot_sub_bar(plt, que_pd['length'], que_pd['count'], title='prompt length', color='c')

    plt.subplot(2, 2, 2)
    plot_sub_bar(plt, ans_pd['length'], ans_pd['count'], title='response length', color='g')

    le512_pd = ans_pd[ans_pd['length'] < 512]
    plt.subplot(2, 2, 3)
    plot_sub_bar(plt, le512_pd['length'], le512_pd['count'], title='response length < 512', color='limegreen')

    le320_pd = ans_pd[ans_pd['length'] < 320]
    plt.subplot(2, 2, 4)
    plot_sub_bar(plt, le320_pd['length'], le320_pd['count'], title='response length < 320', color='limegreen')

    plt.savefig(PROJECT_ROOT + '/img/sentence_length.png')
    plt.show()

def process_belle_knowledge_enhanced_dataset_for_finetune(max_len: int=320, group_cnt: int=50000) -> None:
    '''
    Process the open-source BELLE knowledge-enhanced dataset for fine-tuning.
    '''
    file_names = [
        '/data/raw_data/bell_open_source/Belle_open_source_0.5M.json',
        '/data/raw_data/bell_open_source/train_conv_2.json',
        '/data/raw_data/bell_open_source/generated_chat_0.4M.json',
    ]

    save_file = PROJECT_ROOT + '/data/my_finetune_data_zh.parquet'

    # later writes append, so delete the file first if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    def process_function(line: str) -> dict:
        '''
        Per-line processing function.
        '''
        item = ujson.loads(line)
        prompt = item['instruction']
        response = item['output']

        # drop translation tasks
        if 'translate' in prompt.lower(): return None
        for word in ('翻译', '英译', '译英', '中译', '译中', '汉译', '译汉'):
            if word in prompt:
                return None

        # drop table-related tasks
        if '表格' in prompt or '-----' in prompt or '-----' in response:
            return None

        if len(prompt) > max_len or len(response) > max_len:
            return None

        write_dict = {
            'prompt': prompt,
            'response': response
        }

        return write_dict

    for file in file_names:
        file = PROJECT_ROOT + file

        read_and_write_template(file, save_file, process_function)

if __name__ == '__main__':

    processed_file_dir = PROJECT_ROOT + '/data/my_data'
    if not exists(processed_file_dir):
        mkdir(processed_file_dir)

    # the steps below are commented out so finished stages are not re-processed
    # 1.
    # process_web_text(keep_start=5, response_less_word=15)

    # 2.
    # process_bake_qa(response_less_word=15)

    # 3.
    # process_chinese_medical_datasets(response_less_word=15)

    # 4. the quality of the finance QA dataset is too poor
    # process_finace_dataset(prompt_less_word=10, response_less_word=15)

    # 5.
    # process_zhihu_kol_dataset(prompt_less_word=4, response_less_word=10)

    # 6.
    # process_belle_knowledge_enhanced_dataset(response_less_words=5)

    # convert_wiki_to_simple_zh()

    # 7.
    # process_zh_wiki_data_to_datset(groups_cnt=10000, max_len=512)

    #=================================================================

    # merge
    # merge_dataset_as_single_file(groups_cnt=50000, min_len=3, max_len=512, cut_max_len=True)

    remove_dataset_duplicate_rows(groups_cnt=50000)

    # # shuffle
    # shuffle_parquet_dataset(
    #     parquet_file=PROJECT_ROOT + '/data/my_dataset.parquet',
    #     shuffle_file=PROJECT_ROOT + '/data/my_dataset.shuffle.parquet',
    #     seed=23333
    # )

    # split into train, validation, and test sets
    # split_train_valid_test_datasets(
    #     source_parquet_file=PROJECT_ROOT + '/data/my_dataset.shuffle.parquet',
    #     max_len=320,
    #     groups_cnt=50000
    # )

    # parquet_to_text()

    # count_my_parquet_data(PROJECT_ROOT + '/data/my_dataset.parquet')

    # dataset_length_cnt()

    # process_belle_knowledge_enhanced_dataset_for_finetune(max_len=320, group_cnt=50000)
    # count_my_parquet_data(PROJECT_ROOT + '/data/')

    parquet_to_json()

    # count_my_json_data()