import ujson
import re
from os.path import dirname, abspath, exists, isdir
from os import remove, mkdir, walk
import time
from collections import defaultdict
from matplotlib import pyplot as plt
import codecs, csv
import pandas as pd
import numpy as np
from rich import progress
from rich.table import Table
from rich.console import Console
from fastparquet import ParquetFile, write
import pyarrow.parquet as pq
from opencc import OpenCC

import sys
sys.path.extend(['.', '..'])

from logger import Logger
from config import PROJECT_ROOT
from utils.functions import get_path_of_suffix_files, DropDatasetDuplicate

log = Logger('data_process', save2file=True, file_name=PROJECT_ROOT + '/logs/raw_data_process.log')

# Plain character set (not a regex), so whitespace has to be listed literally.
punctuation = set("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~.,;《》?!“”‘’@#¥%…&×()——+【】{};;●,。&~、| ::\n")
en_punctuation = ",().!;:"
zh_punctuation = ",()。!;:"


def delete_file(file: str) -> bool:
    '''
    Ask for confirmation before deleting a file.
    '''
    if exists(file):
        ans = input('delete file: {} ? Yes (y) or No (n)'.format(file))
        ans = ans.lower()
        if ans in ('yes', 'y'):
            remove(file)
            print('deleted.')
            return True
    return False


def remove_duplicate_punctuation(sentence: str) -> str:
    '''
    Remove duplicated punctuation and duplicated spaces from a sentence;
    newlines are handled like the special character '\\n'.
    '''
    # Replace half-width/full-width spaces with commas; any resulting runs of
    # repeated punctuation are collapsed below.
    sentence = re.sub(' | ', ',', sentence)

    ans = ''
    n = len(sentence)
    p = 0
    while p < n:
        ans += sentence[p]
        # skip ahead while the current and the next character are both punctuation
        while p + 1 < n and sentence[p] in punctuation and sentence[p + 1] in punctuation:
            p += 1
        p += 1
    return ans


def convert_en_punctuation_to_zh_punct(sentence: str) -> str:
    '''
    Replace English punctuation in a sentence with Chinese punctuation.
    '''
    n = len(zh_punctuation)
    for i in range(n):
        sentence = sentence.replace(en_punctuation[i], zh_punctuation[i])
    return sentence


def get_sentences_dice_similarity(st_a: str, st_b: str) -> float:
    '''
    Dice similarity of two sentences:
    s(a, b) = 2 * len(set(a) & set(b)) / (len(set(a)) + len(set(b)))
    '''
    set_a, set_b = set(st_a), set(st_b)
    total_len = len(set_a) + len(set_b)

    if total_len == 0:
        return 0.0

    inter_set = set_a & set_b
    return (2 * len(inter_set)) / total_len


def write_single_parquet_file(file_name: str, data_frame: pd.DataFrame) -> None:
    '''
    Write a DataFrame to a single parquet file, appending if the file already exists.
    '''
    append = False
    if exists(file_name):
        append = True

    write(file_name, data_frame, compression='GZIP', append=append)

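
# Quick illustrative checks for the cleaning helpers above (the sample strings
# are made up; expected outputs assume the punctuation tables defined in this file):
#   remove_duplicate_punctuation('你好,,,,世界!!')    -> '你好,世界!'
#   convert_en_punctuation_to_zh_punct('你好,世界.')    -> '你好,世界。'
#   get_sentences_dice_similarity('abc', 'abd')         -> 2/3 ≈ 0.67
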
def read_and_write_template(read_file: str, write_to_file: str, call_back: object, group_cnt: int=10000) -> None:
    '''
    Read/process/write template. A callback function `call_back` must be provided.
    read_file: raw data file
    write_to_file: file the processed data is saved to
    call_back: takes one raw line (str) and returns the processed dict; it must return None for invalid lines
    group_cnt: number of rows flushed to the parquet file at a time
    Example:
    >>> def call_back(inputs: str) -> dict:
    >>>     if not check(inputs):
    >>>         # invalid input
    >>>         return None
    >>>     ...  # do something with inputs
    >>>     my_dict = {
    >>>         'prompt': inputs['p'],
    >>>         'response': inputs['a1'] + inputs['a2'],
    >>>         ...
    >>>     }
    >>>     return my_dict
    '''
    log.info('process file:{}'.format(read_file), save_to_file=True)
    start = time.time()

    raw_line_cnt = 0
    keep_line_cnt = 0

    with progress.open(read_file, 'r', encoding='utf-8') as f_read:
        cur_rows = []
        append = cur_rows.append
        for line in f_read:
            try:
                raw_line_cnt += 1
                write_dict = call_back(line)

                if write_dict is None:
                    continue

                keep_line_cnt += 1
                append(write_dict)
                # ujson.dump(write_obj, f_write, indent=4, ensure_ascii=False)
                # ujson.dump(write_obj, f_write, ensure_ascii=False,)
                # f_write.write('\n')

                if len(cur_rows) >= group_cnt:
                    df = pd.DataFrame(cur_rows)
                    write_single_parquet_file(write_to_file, df)
                    cur_rows = []
                    append = cur_rows.append

            except Exception as e:
                # log.error('exception while processing file: {}, content: {}'.format(str(e), line))
                print(line)
                raise e
        # end for

        # flush the tail rows
        if len(cur_rows) > 0:
            df = pd.DataFrame(cur_rows)
            write_single_parquet_file(write_to_file, df)
            cur_rows = []

    end = time.time()

    log.info('原始文件:{},共{}行,处理后剩余{}行,保存到文件:{}。耗时:{:.6}s'\
        .format(read_file, raw_line_cnt, keep_line_cnt, write_to_file, end - start), save_to_file=True)


# ===================================== dataset processing =================================

def process_web_text(keep_start: int=5, response_less_word: int=10) -> None:
    '''
    Process the 4.25M-sample community Q&A dataset webtext2019zh.
    keep_start: keep only Q&A pairs with at least keep_start upvotes
    response_less_word: the answer must contain at least response_less_word characters
    '''
    file_names = [
        '/data/raw_data/web_text_zh_test.json',
        '/data/raw_data/web_text_zh_train.json',
        '/data/raw_data/web_text_zh_valid.json',
    ]

    save_file_name = PROJECT_ROOT + '/data/my_data/my_web_text_zh.parquet'
    # rows are appended later, so delete the file if it already exists
    if exists(save_file_name):
        assert delete_file(save_file_name)

    def process_function(line: str) -> dict:
        item = ujson.loads(line)

        if item['star'] < keep_start or len(item['content']) < response_less_word:
            return None

        # data cleaning: drop duplicated punctuation
        prompt = remove_duplicate_punctuation(item['title'])
        response = remove_duplicate_punctuation(item['content'])
        write_dict = {
            "prompt": prompt,
            "response": response,
        }
        return write_dict

    for file_name in file_names:
        read_file = PROJECT_ROOT + file_name
        read_and_write_template(read_file, save_file_name, process_function)


def process_bake_qa(response_less_word: int=15) -> None:
    '''
    Process the 1.47M-sample Baidu Zhidao knowledge Q&A dataset.
    '''
    file_names = [
        '/data/raw_data/baike_qa_train.json',
        '/data/raw_data/baike_qa_valid.json',
    ]

    save_file_name = PROJECT_ROOT + '/data/my_data/my_baike_qa.parquet'
    # rows are appended later, so delete the file if it already exists
    if exists(save_file_name):
        assert delete_file(save_file_name)

    def process_function(line: str) -> dict:
        item = ujson.loads(line)

        if len(item['answer']) < response_less_word:
            return None

        # data cleaning
        prompt = ''
        if get_sentences_dice_similarity(item['title'], item['desc']) >= 0.90:
            # title and desc are nearly identical, use only title as the prompt
            prompt = item['title']
        else:
            # otherwise concatenate title and desc to form the prompt
            prompt = "{}{}".format(item['title'], item['desc'])

        # remove \r and duplicated punctuation
        prompt = prompt.replace('\r', '')
        prompt = remove_duplicate_punctuation(prompt)

        response = item['answer'].replace('\r', '')
        response = remove_duplicate_punctuation(response)

        # drop samples whose prompt or response is too short
        if len(prompt) < 3 or len(response) < response_less_word:
            return None

        write_dict = {
            "prompt": prompt,
            "response": response,
        }
        return write_dict

    for file_name in file_names:
        read_file = PROJECT_ROOT + file_name
        read_and_write_template(read_file, save_file_name, process_function)

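
# Minimal sanity check for any of the parquet files produced above (a sketch,
# assuming pandas can read the GZIP-compressed output written by fastparquet;
# the file name is just an example):
#   df = pd.read_parquet(PROJECT_ROOT + '/data/my_data/my_web_text_zh.parquet')
#   print(df.shape)
#   print(df.head())
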
def repair_line_error_csv_file(raw_csv_file: str, save_suffix: str, read_encoding: str='utf-8') -> None:
    '''
    Repair a csv file: replace newlines inside fields with a literal '\\n'
    and English commas inside fields with Chinese commas.
    '''
    with codecs.open(raw_csv_file, 'r', encoding=read_encoding, errors='ignore') as f:
        reader = csv.reader(f)
        new_lines = []

        for line in reader:
            for i in range(len(line)):
                line[i] = line[i].replace('\n', '\\n')  # fix stray line breaks inside fields
                line[i] = line[i].replace(',', ',')     # English comma -> Chinese comma

            new_lines.append(line)

        with open(raw_csv_file[: -4] + save_suffix, 'w', encoding='utf-8', newline="") as f:
            writer = csv.writer(f)
            writer.writerows(new_lines)


def process_chinese_medical_datasets(response_less_word: int=15) -> None:
    '''
    Process the Chinese medical domain Q&A datasets.
    '''
    raw_dataset_dir = PROJECT_ROOT + '/data/raw_data/chinese_medical_dialogue_datasets'

    raw_data_files = get_path_of_suffix_files(raw_dataset_dir, '.csv')

    # if the repaired files do not exist yet, fix the csv line-break issues first
    suffix = '.repaired.csv'
    need_to_repair_files = [
        file_name for file_name in raw_data_files \
            if not file_name.endswith(suffix) and file_name[0: -4] + suffix not in raw_data_files
    ]

    # repair files with broken line breaks
    for file_name in need_to_repair_files:
        repair_line_error_csv_file(file_name, suffix, read_encoding='gb2312')

    # re-list the source files (i.e. the repaired ones)
    raw_data_files = get_path_of_suffix_files(raw_dataset_dir, suffix)

    # target file name
    save_file = PROJECT_ROOT + '/data/my_data/my_chinese_medical_dialogue.parquet'
    # for file_name in raw_data_files:
    #     file_name = file_name.split('/')[-1][0: -(len(suffix))] + '.parquet'
    #     file_name = PROJECT_ROOT + '/data/my_data/' + file_name
    #     save_files.append(file_name)

    # rows are appended later, so delete the file if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    def process_function(line: str) -> dict:
        # columns: department,title,ask,answer
        item = line.split(',')  # the csv file is comma separated

        if len(item) < 4:
            print(item)
            return None

        if len(item[3]) < response_less_word:
            return None

        # data cleaning
        prompt = ''
        if get_sentences_dice_similarity(item[1], item[2]) >= 0.90:
            # title and ask are nearly identical, use only ask as the prompt
            prompt = item[2]
        else:
            # otherwise concatenate title and ask to form the prompt
            prompt = "{}{}".format(item[1], item[2])

        # remove \r and duplicated punctuation
        prompt = prompt.replace('\r', '')
        prompt = remove_duplicate_punctuation(prompt)

        response = ''.join(item[3:]).replace('\r', '')
        response = remove_duplicate_punctuation(response)

        # drop samples whose prompt or response is too short
        if len(prompt) < 3 or len(response) < response_less_word:
            return None

        write_dict = {
            "prompt": prompt,
            "response": response,
        }
        return write_dict

    for i, file_name in enumerate(raw_data_files):
        read_file = file_name
        read_and_write_template(read_file, save_file, process_function)


def process_finace_dataset(prompt_less_word: int=10, response_less_word: int=15) -> None:
    '''
    Process the finance Q&A dataset.
    '''
    finace_data_file = PROJECT_ROOT + '/data/raw_data/financezhidao_filter.csv'

    suffix = '.repaired.csv'
    if not exists(finace_data_file[0: -4] + suffix):
        repair_line_error_csv_file(finace_data_file, save_suffix=suffix, read_encoding='utf-8')

    def process_function(line: str) -> dict:
        # columns: title,prompt,reply,is_best
        item = line.split(',')  # the csv file is comma separated

        if len(item) < 4:
            print(item)
            return None

        if len(item[0]) + len(item[1]) < prompt_less_word or len(item[2]) < response_less_word:
            return None

        # data cleaning
        prompt = ''
        if get_sentences_dice_similarity(item[0], item[1]) >= 0.90:
            # title and prompt are nearly identical, keep only the longer one as the question
            prompt = item[0] if len(item[0]) > len(item[1]) else item[1]
        else:
            # otherwise concatenate title and prompt to form the question
            prompt = "{}{}".format(item[0], item[1])

        # remove \r and duplicated punctuation
        prompt = prompt.replace('\r', '')
        prompt = remove_duplicate_punctuation(prompt)

        response = ''.join(item[2]).replace('\r', '')
        response = remove_duplicate_punctuation(response)

        # drop samples whose prompt or response is too short
        if len(prompt) < prompt_less_word or len(response) < response_less_word:
            return None

        write_obj = {
            "prompt": prompt,
            "response": response,
        }
        return write_obj

    read_file = finace_data_file[0: -4] + suffix
    write_file = PROJECT_ROOT + '/data/my_data/' + read_file.split('/')[-1][0: -(len(suffix))] + '.parquet'

    # rows are appended later, so delete the file if it already exists
    if exists(write_file):
        assert delete_file(write_file)

    read_and_write_template(read_file, write_file, process_function)

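
# Design note: the process_function callbacks above split repaired csv rows with a
# plain line.split(',') instead of csv.reader. This only stays safe because
# repair_line_error_csv_file has already rewritten any English commas inside
# fields as Chinese commas, so every remaining ASCII comma is a column separator.
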
def process_zhihu_kol_dataset(prompt_less_word: int=4, response_less_word: int=10, group_cnt: int=10000) -> None:
    '''
    Process the Zhihu KOL dataset.
    '''
    raw_zhihu_data_path = abspath(dirname(dirname(__file__))) + '/data/raw_data/zhihu-kol'
    file_names = []
    suffix = '.parquet'
    for root, _, files in walk(raw_zhihu_data_path):
        for file in files:
            if file.endswith(suffix):
                file_names.append(root + '/' + file)

    def process_function(sentence: str) -> str:
        '''
        Clean a single sentence.
        '''
        # remove \r and duplicated punctuation
        sentence = sentence.replace('\r', '')
        sentence = remove_duplicate_punctuation(sentence)
        return sentence

    # row keys: ['INSTRUCTION', 'RESPONSE', 'SOURCE', 'METADATA']
    save_file = PROJECT_ROOT + '/data/my_data/zhihu_kol.parquet'
    # rows are appended later, so delete the file if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    all_cnt, keep_cnt = 0, 0
    cur_rows = []
    append = cur_rows.append
    for file in file_names:
        pf = pq.read_table(file)
        log.info('process file: {}'.format(file), save_to_file=True)

        for prompt, response in progress.track(zip(pf['INSTRUCTION'], pf['RESPONSE']), total=pf.num_rows):
            all_cnt += 1
            prompt, response = prompt.as_py(), response.as_py()

            prompt = process_function(prompt)
            response = process_function(response)

            if len(prompt) < prompt_less_word or len(response) < response_less_word:
                continue

            keep_cnt += 1
            write_dict = {
                'prompt': prompt,
                'response': response,
            }
            append(write_dict)

            if len(cur_rows) >= group_cnt:
                df = pd.DataFrame(cur_rows)
                write_single_parquet_file(save_file, df)
                cur_rows = []
                append = cur_rows.append
    # end for

    if len(cur_rows) > 0:
        df = pd.DataFrame(cur_rows)
        write_single_parquet_file(save_file, df)
        cur_rows = []

    log.info('save file to: {}, 全部数据共{}行,清洗后剩余{}行'.format(save_file, all_cnt, keep_cnt), save_to_file=True)


def process_belle_knowledge_enhanced_dataset(response_less_words: int=15, group_cnt: int=10000) -> None:
    '''
    Process the BELLE open-source knowledge-enhanced datasets.
    '''
    file_names = [
        '/data/raw_data/bell_open_source/train_2M_CN.json',
        '/data/raw_data/bell_open_source/train_0.8M_CN.json',
        '/data/raw_data/bell_open_source/Belle_open_source_1M.json',
    ]

    save_file = PROJECT_ROOT + '/data/my_data/my_belll_3M_cn.parquet'
    # rows are appended later, so delete the file if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    def process_function(line: str) -> dict:
        '''
        Per-line processing function.
        '''
        item = ujson.loads(line)
        prompt = item['instruction']
        response = item['output']

        # drop translation tasks
        if '翻译' in prompt or 'translate' in prompt.lower():
            return None

        # drop table-related tasks
        if '表格' in prompt or '-----' in prompt or '-----' in response:
            return None

        if len(response) < response_less_words:
            return None

        prompt = remove_duplicate_punctuation(prompt)
        response = remove_duplicate_punctuation(response)

        if len(response) < response_less_words:
            return None

        write_dict = {
            'prompt': prompt,
            'response': response
        }
        return write_dict

    for file in file_names:
        file = PROJECT_ROOT + file
        read_and_write_template(file, save_file, process_function)

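
# Each BELLE file is JSON-lines; only the two keys read above ('instruction' and
# 'output') are relied on here. Illustrative line (content is made up):
#   {"instruction": "请解释什么是机器学习", "output": "机器学习是……"}
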
def convert_wiki_to_simple_zh(buffer_size: int=10000) -> None:
    '''
    Convert the traditional-Chinese wiki text to simplified Chinese.
    '''
    raw_zh_wiki_file = PROJECT_ROOT + '/data/raw_data/wiki.txt'
    save_zh_wiki_simple_file = PROJECT_ROOT + '/data/raw_data/wiki.simple.txt'

    if exists(save_zh_wiki_simple_file):
        assert delete_file(save_zh_wiki_simple_file)

    cc = OpenCC('t2s')
    cur_rows = []
    append = cur_rows.append

    def procees_line(line: str) -> str:
        '''
        Process one line of text.
        '''
        # traditional -> simplified
        line = cc.convert(line)

        line = re.sub(r"\「|\」|\「|\」|\『|\』", '\"', line)  # replace 「」「」『』 with quotation marks
        line = re.sub(r"\,\)|\;\)", ')', line)                # e.g. "罗德·法尼(Rod Dodji Fanni,)"
        line = re.sub(r"\(\,|\(\,", '(', line)               # e.g. "阿魯拉·基馬(Alula Girma (,"
        line = convert_en_punctuation_to_zh_punct(line)        # English punctuation -> Chinese punctuation
        line = remove_duplicate_punctuation(line)               # drop empty Chinese brackets and duplicated punctuation

        return line

    with progress.open(raw_zh_wiki_file, 'r', encoding='utf-8') as read_f:
        with open(save_zh_wiki_simple_file, 'a', encoding='utf-8') as write_f:
            for line in read_f:
                line = procees_line(line)
                if len(line.strip()) == 0:
                    continue

                line = '{}\n'.format(line)
                append(line)

                if len(cur_rows) >= buffer_size:
                    write_f.writelines(cur_rows)
                    cur_rows = []
                    append = cur_rows.append

            if len(cur_rows) > 0:
                write_f.writelines(cur_rows)
                cur_rows = []


def process_zh_wiki_data_to_datset(groups_cnt: int=10000, max_len: int=512, seed: int=23333) -> None:
    '''
    Convert the Chinese wiki data into a Q&A dataset.
    wiki dump download: https://dumps.wikimedia.org/zhwiki/
    converting the downloaded bz2 file to wiki.txt: https://github.com/apertium/WikiExtractor
    '''
    raw_zh_wiki_file = PROJECT_ROOT + '/data/raw_data/wiki.txt'
    zhwiki_simple_file = PROJECT_ROOT + '/data/my_data/wiki_zh_simple.parquet'

    # delete existing output
    if exists(zhwiki_simple_file):
        assert delete_file(zhwiki_simple_file)

    # traditional -> simplified
    cc = OpenCC('t2s')
    all_cnt, keep_cnt = 0, 0

    # prompt templates used to turn an entry title into a question
    prompt_prefix = [
        '什么是{}?',
        '介绍一下{}',
        '介绍一下什么是{}',
        '写一篇关于{}的介绍',
        '{}是什么?',
        '你知道{}吗?',
        '生成关于{}的介绍',
        '我想知道关于{}的详细信息',
        '你了解{}吗?',
        '请解释一下{}',
        '对于{},你有什么了解或看法吗?',
        '请告诉我关于{}的信息',
        '请简要描述一下{}',
        '请提供有关{}的一些详细信息',
        '能否解释一下{}是什么?',
        '请分享一些关于{}的背景知识',
        '请简要概括一下{}',
        '能给我一些关于{}的背景资料吗?',
        '有关{}的信息可以分享一下吗?',
        '你能告诉我{}是什么吗?',
    ]

    def procees_line(line: str) -> str:
        '''
        Process one line of text.
        '''
        # traditional -> simplified
        line = cc.convert(line)

        line = re.sub(r"\「|\」|\「|\」|\『|\』", '\"', line)  # replace 「」「」『』 with quotation marks
        line = re.sub(r"\,\)|\;\)", ')', line)                # e.g. "罗德·法尼(Rod Dodji Fanni,)"
        line = re.sub(r"\(\,|\(\,", '(', line)               # e.g. "阿魯拉·基馬(Alula Girma (,"
        line = convert_en_punctuation_to_zh_punct(line)        # English punctuation -> Chinese punctuation
        line = remove_duplicate_punctuation(line)               # drop empty Chinese brackets and duplicated punctuation

        return line

    np.random.seed(seed)
    choice = np.random.choice

    with progress.open(raw_zh_wiki_file, 'r', encoding='utf-8') as read_file:
        prompt = ''
        response = ''
        pre_line_len = 0
        cur_rows = []
        append = cur_rows.append
        for line in read_file:
            all_cnt += 1

            # the prompt has already been saved, but the entry still has extra lines;
            # they would push the response past max_len, so skip them
            if len(prompt) == 0 and pre_line_len > 0:
                pre_line_len = len(line.strip())
                continue

            # clean the line
            line = procees_line(line)

            # determine the prompt: pre_line_len == 0 means the previous line was blank,
            # so the current line starts a new wiki entry and becomes the prompt
            if prompt == '' and line.endswith(':') and pre_line_len == 0:
                prompt = choice(prompt_prefix).format(line[0: -1])
                continue

            pre_line_len = len(line.strip())

            # the lines following the prompt form the answer
            if prompt != '' and not line.endswith(':'):
                # pre_line_len is already len(line.strip()); if it is 0 the current entry
                # has ended, so the sample must be saved regardless of the answer length
                if len(response) + len(line) <= max_len and pre_line_len != 0:
                    response = '{}{}'.format(response, line)
                elif len(response) + len(line) > max_len or pre_line_len == 0:
                    # the answer is too long or the current entry has ended: save one sample
                    keep_cnt += 1
                    response = '{}{}'.format(response, line)
                    append({'prompt': prompt, 'response': ''.join(response[0: max_len])})

                    prompt = ''
                    response = ''

            # flush to file every groups_cnt rows
            if len(cur_rows) >= groups_cnt:
                df = pd.DataFrame(cur_rows)
                write_single_parquet_file(zhwiki_simple_file, df)
                cur_rows = []
                append = cur_rows.append
        # end for

        if len(prompt) > 0 and len(response) > 0:
            keep_cnt += 1
            append({'prompt': prompt, 'response': response})

        if len(cur_rows) > 0:
            df = pd.DataFrame(cur_rows)
            write_single_parquet_file(zhwiki_simple_file, df)
            cur_rows = []

    log.info("merge into file: {}, 全部数据共{}行,清洗后剩余{}行".format(zhwiki_simple_file, all_cnt, keep_cnt), save_to_file=True)

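
# Expected shape of wiki.txt for the conversion above, as implied by the parsing
# logic (the entry below is illustrative, not real dump content): an entry title
# is a line ending with ':' that follows a blank line, and the lines after it
# form the entry body until the next blank line.
#
#   数学:
#   数学是研究数量、结构、变化以及空间等概念的一门学科。
#   ...
#   (blank line)
#   物理学:
#   ...
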
def merge_dataset_as_single_file(groups_cnt: int=50000, max_len: int=512, min_len: int=3, cut_max_len: bool=False) -> None:
    '''
    Merge several datasets into a single dataset.
    '''
    from_parquet_files = get_path_of_suffix_files(PROJECT_ROOT + '/data/my_data', '.parquet')

    save_file = PROJECT_ROOT + '/data/my_dataset.parquet'
    # rows are appended later, so delete the file if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    cur_rows = []
    append = cur_rows.append

    all_cnt, keep_cnt = 0, 0
    for file in from_parquet_files:
        print('process file: {}'.format(file))

        parquet_table = pq.read_table(file)

        for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):
            prompt, response = prompt.as_py(), response.as_py()
            all_cnt += 1

            if len(prompt) < min_len or len(response) < min_len:
                continue

            if cut_max_len and (len(prompt) > max_len or len(response) > max_len):
                prompt = prompt[0: max_len]
                response = response[0: max_len]

            keep_cnt += 1
            append({'prompt': prompt, 'response': response})

            if len(cur_rows) >= groups_cnt:
                df = pd.DataFrame(cur_rows)
                write_single_parquet_file(save_file, df)
                cur_rows = []
                append = cur_rows.append

    # flush the tail rows
    if len(cur_rows) > 0:
        df = pd.DataFrame(cur_rows)
        write_single_parquet_file(save_file, df)
        cur_rows = []

    log.info("merge into file: {}, 全部数据共{}行,清洗后剩余{}行".format(save_file, all_cnt, keep_cnt), save_to_file=True)


def remove_dataset_duplicate_rows(groups_cnt: int=50000) -> None:
    '''
    Drop near-duplicate rows from the dataset using MinHash.
    '''
    from_parquet_files = PROJECT_ROOT + '/data/my_dataset.parquet'

    save_file = PROJECT_ROOT + '/data/my_dataset_no_dulpticates.parquet'
    # rows are appended later, so delete the file if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    cur_rows = []
    all_cnt, keep_cnt = 0, 0
    row_index = -1
    drop_dataset_duplicate = DropDatasetDuplicate(threshold=0.85, num_perm=256)

    parquet_table = pq.read_table(from_parquet_files)
    all_cnt = parquet_table.num_rows

    # first pass: find which rows are duplicates
    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):
        row_index += 1

        doc = f"{prompt.as_py()}{response.as_py()}"
        drop_dataset_duplicate.add_doc(index=row_index, doc=doc)

    row_index = -1
    need_to_drop_indexs = drop_dataset_duplicate.get_duplicate_indexs()

    # second pass: write out only the rows that are not duplicates
    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):
        row_index += 1  # row_index must be incremented whether or not the row is kept

        # skip duplicated rows
        if row_index in need_to_drop_indexs:
            continue

        cur_rows.append({'prompt': prompt.as_py(), 'response': response.as_py()})
        keep_cnt += 1

        if len(cur_rows) >= groups_cnt:
            df = pd.DataFrame(cur_rows)
            write_single_parquet_file(save_file, df)
            cur_rows = []

    # flush the tail rows
    if len(cur_rows) > 0:
        df = pd.DataFrame(cur_rows)
        write_single_parquet_file(save_file, df)

    log.info("merge into file: {}, 全部数据共{}行,文档去重后剩余{}行".format(save_file, all_cnt, keep_cnt), save_to_file=True)


def shuffle_parquet_dataset(parquet_file: str, shuffle_file: str, seed: int=23333, groups_cnt: int=65536) -> None:
    '''
    Shuffle a parquet dataset.
    '''
    if not exists(parquet_file):
        raise Exception('can not find parquet file: {}'.format(parquet_file))

    print('start shuffle...')
    pf = pq.read_table(parquet_file)
    df = pf.to_pandas()
    df = df.sample(frac=1.0, replace=False, random_state=seed, axis=0)

    if exists(shuffle_file):
        assert delete_file(shuffle_file)

    # write the parquet file in chunks, otherwise reading it back on a low-memory machine will OOM
    n = len(df)
    for i in range(0, n, groups_cnt):
        cur_group_df = df[i: i + groups_cnt]
        write_single_parquet_file(shuffle_file, cur_group_df)

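
# Design note on remove_dataset_duplicate_rows: it iterates the parquet table twice.
# The first pass feeds every prompt+response string into DropDatasetDuplicate
# (a MinHash-based helper from utils.functions, configured here with threshold=0.85
# and num_perm=256) to collect the indices of near-duplicate rows; the second pass
# rewrites the dataset while skipping exactly those indices, so the surviving rows
# keep their original order.
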
def count_my_json_data() -> None:
    '''
    Count the number of rows in all current json datasets.
    '''
    my_data_files = get_path_of_suffix_files(PROJECT_ROOT + '/data/my_data', '.json')
    result = [['file_name', 'count']]
    all_cnt = 0
    for file in my_data_files:
        file_name = file.split('/')[-1]
        cur_cnt = 0
        with progress.open(file, 'r', encoding='utf-8') as f:
            for _ in f:
                cur_cnt += 1

        all_cnt += cur_cnt
        result.append([file_name, cur_cnt])

    result.append(['汇总', all_cnt])

    log.info(str(result), save_to_file=True)

    console = Console()
    table = Table(show_header=True, show_lines=True,)

    for col in result[0]:
        table.add_column(col)
    for i in range(1, len(result)):  # skip the header row
        table.add_row(str(result[i][0]), str(result[i][1]))

    console.print(table)


def count_my_parquet_data(parquet_file: str=None) -> None:
    '''
    Count the number of rows in all parquet datasets under a directory (or in a single file).
    '''
    my_data_files = []

    if not parquet_file:
        my_data_files = get_path_of_suffix_files(PROJECT_ROOT + '/data/my_data', '.parquet')
    elif isdir(parquet_file):
        my_data_files = get_path_of_suffix_files(parquet_file, '.parquet')
    elif parquet_file.endswith('.parquet'):
        my_data_files = [parquet_file]

    result = [['file_name', 'count']]
    all_cnt = 0
    for file in my_data_files:
        file_name = file.split('/')[-1]
        cur_cnt = 0
        pf = ParquetFile(file)

        for pf_chunk in pf:
            cur_cnt += pf_chunk.info['rows']

        all_cnt += cur_cnt
        result.append([file_name, cur_cnt])

    result.append(['汇总', all_cnt])

    log.info(str(result), save_to_file=True)

    console = Console()
    table = Table(show_header=True, show_lines=True,)

    for col in result[0]:
        table.add_column(col)
    for i in range(1, len(result)):  # skip the header row
        table.add_row(str(result[i][0]), str(result[i][1]))

    console.print(table)


def split_train_valid_test_datasets(source_parquet_file: str, max_len: int=320, seed: int=23333,
                                    train_ratio: float=0.91, test_ratio: float=0.0875, valid_ratio: float=0.0025,
                                    groups_cnt: int=50000) -> None:
    '''
    Split the source data into train, test and validation sets.
    '''
    assert train_ratio + test_ratio + valid_ratio == 1.0

    train_parquet_file = PROJECT_ROOT + '/data/my_train_dataset.parquet'
    test_parquet_file = PROJECT_ROOT + '/data/my_test_dataset.parquet'
    valid_parquet_file = PROJECT_ROOT + '/data/my_valid_dataset.parquet'

    if exists(train_parquet_file):
        assert delete_file(train_parquet_file)
    if exists(test_parquet_file):
        assert delete_file(test_parquet_file)
    if exists(valid_parquet_file):
        assert delete_file(valid_parquet_file)

    np.random.seed(seed)

    train, test, valid = [], [], []

    parquet_table = pq.read_table(source_parquet_file)

    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):

        prompt, response = prompt.as_py(), response.as_py()
        rand = np.random.random()
        cur_data = {'prompt': ''.join(prompt[0: max_len]), 'response': ''.join(response[0: max_len])}

        if 0 <= rand < train_ratio:
            train.append(cur_data)
        elif train_ratio <= rand < train_ratio + test_ratio:
            test.append(cur_data)
        else:
            valid.append(cur_data)

        if len(train) >= groups_cnt:
            write_single_parquet_file(train_parquet_file, pd.DataFrame(train))
            train = []

        if len(test) >= groups_cnt:
            write_single_parquet_file(test_parquet_file, pd.DataFrame(test))
            test = []

        if len(valid) >= groups_cnt:
            write_single_parquet_file(valid_parquet_file, pd.DataFrame(valid))
            valid = []

    # flush the remaining rows
    if len(train) > 0:
        write_single_parquet_file(train_parquet_file, pd.DataFrame(train))
        train = []

    if len(test) > 0:
        write_single_parquet_file(test_parquet_file, pd.DataFrame(test))
        test = []

    if len(valid) > 0:
        write_single_parquet_file(valid_parquet_file, pd.DataFrame(valid))
        valid = []

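
# Note on the split above: each row is assigned to exactly one bucket by a single
# uniform draw in [0, 1): train for [0, 0.91), test for [0.91, 0.9975) and valid
# for the remainder (with the default ratios), so the resulting set sizes only
# approximate the requested proportions.
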
def parquet_to_text(sep='[SEP]', buffer_size: int=50000) -> None:
    '''
    Convert a parquet file into a txt corpus, with sentences separated by sep.
    The txt file is used to train the tokenizer; training a huggingface BPE tokenizer
    on the raw data directly can run out of memory.
    '''
    parquet_file = PROJECT_ROOT + '/data/my_dataset.parquet'
    txt_file = PROJECT_ROOT + '/data/my_corpus.txt'

    if exists(txt_file):
        assert delete_file(txt_file)

    source_pf = ParquetFile(parquet_file)
    cur_rows = []
    append = cur_rows.append
    with open(txt_file, 'a', encoding='utf-8') as f_write:
        for pf_chunk in progress.track(source_pf):
            for rows in pf_chunk.iter_row_groups():
                for prompt, response in zip(rows['prompt'], rows['response']):
                    append(prompt + sep + response + sep + '\n')

                    if len(cur_rows) >= buffer_size:
                        f_write.writelines(cur_rows)
                        cur_rows = []
                        append = cur_rows.append
        # end for

        # flush the tail rows
        if len(cur_rows) > 0:
            f_write.writelines(cur_rows)
            cur_rows = []

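
# Each line written by parquet_to_text has the form (the example content is made up):
#   <prompt>[SEP]<response>[SEP]
# e.g. 什么是白细胞?[SEP]白细胞是血液中的一类免疫细胞……[SEP]
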
def parquet_to_json() -> None:
    '''
    Convert a parquet file to json.
    '''
    parquet_file = PROJECT_ROOT + '/data/my_finetune_data_zh.parquet'
    json_file = PROJECT_ROOT + '/data/sft_train.json'

    if exists(json_file):
        assert delete_file(json_file)

    source_pf = ParquetFile(parquet_file)
    cur_rows = []
    append = cur_rows.append

    for pf_chunk in progress.track(source_pf):
        for rows in pf_chunk.iter_row_groups():
            for prompt, response in zip(rows['prompt'], rows['response']):
                if len(response) == 0 or len(prompt) == 0:
                    continue

                append({
                    'prompt': str(prompt),
                    'response': str(response),
                })

    with open(json_file, 'w', encoding='utf-8') as f:
        ujson.dump(cur_rows, f, indent=4, ensure_ascii=False)


def dataset_length_cnt() -> None:
    '''
    Plot the prompt/response length distribution of the shuffled dataset.
    '''
    dataset_file = PROJECT_ROOT + '/data/my_dataset.shuffle.parquet'
    parquet_table = pq.read_table(dataset_file)

    que_len_dict, ans_len_dict = defaultdict(int), defaultdict(int)

    for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows):
        prompt, response = prompt.as_py(), response.as_py()

        que_len_dict[len(prompt)] += 1
        ans_len_dict[len(response)] += 1

    que_len, ans_len = [], []
    for k, v in que_len_dict.items():
        que_len.append([k, v])
    for k, v in ans_len_dict.items():
        ans_len.append([k, v])

    def gather_gt_x(array: list[tuple], x: int=512) -> list:
        '''
        Merge all lengths greater than or equal to x into a single bucket.
        '''
        new_array = []
        gt_x_cnt = 0
        for item in array:
            if item[0] < x:
                new_array.append([item[0], item[1]])
            else:
                gt_x_cnt += item[1]
        new_array.append([x, gt_x_cnt])

        return new_array

    max_len = 512
    ans_list = gather_gt_x(ans_len, max_len)
    ans_list.sort(key=lambda x: x[0])
    que_list = gather_gt_x(que_len, max_len)
    que_list.sort(key=lambda x: x[0])

    ans_pd = pd.DataFrame(ans_list, columns=['length', 'count'])
    que_pd = pd.DataFrame(que_list, columns=['length', 'count'])

    def plot_sub_bar(plt, x, y, title: str, color: str='g') -> None:
        plt.bar(x, y, color=color, label='sample count')
        plt.ticklabel_format(style='sci', scilimits=(0, 0), axis='y')
        plt.legend()
        plt.xlabel('length')
        plt.ylabel('count')
        plt.title(title)

    plt.figure(figsize=(10, 10), dpi=200)

    plt.subplot(2, 2, 1)
    plot_sub_bar(plt, que_pd['length'], que_pd['count'], title='prompt length', color='c')

    plt.subplot(2, 2, 2)
    plot_sub_bar(plt, ans_pd['length'], ans_pd['count'], title='response length', color='g')

    le512_pd = ans_pd[ans_pd['length'] < 512]
    plt.subplot(2, 2, 3)
    plot_sub_bar(plt, le512_pd['length'], le512_pd['count'], title='response length < 512', color='limegreen')

    le320_pd = ans_pd[ans_pd['length'] < 320]
    plt.subplot(2, 2, 4)
    plot_sub_bar(plt, le320_pd['length'], le320_pd['count'], title='response length < 320', color='limegreen')

    plt.savefig(PROJECT_ROOT + '/img/sentence_length.png')
    plt.show()


def process_belle_knowledge_enhanced_dataset_for_finetune(max_len: int=320, group_cnt: int=50000) -> None:
    '''
    Process the BELLE open-source knowledge-enhanced datasets for fine-tuning.
    '''
    file_names = [
        '/data/raw_data/bell_open_source/Belle_open_source_0.5M.json',
        '/data/raw_data/bell_open_source/train_conv_2.json',
        '/data/raw_data/bell_open_source/generated_chat_0.4M.json',
    ]

    save_file = PROJECT_ROOT + '/data/my_finetune_data_zh.parquet'
    # rows are appended later, so delete the file if it already exists
    if exists(save_file):
        assert delete_file(save_file)

    def process_function(line: str) -> dict:
        '''
        Per-line processing function.
        '''
        item = ujson.loads(line)
        prompt = item['instruction']
        response = item['output']

        # drop translation tasks
        if 'translate' in prompt.lower():
            return None
        for word in ('翻译', '英译', '译英', '中译', '译中', '汉译', '译汉'):
            if word in prompt:
                return None

        # drop table-related tasks
        if '表格' in prompt or '-----' in prompt or '-----' in response:
            return None

        if len(prompt) > max_len or len(response) > max_len:
            return None

        write_dict = {
            'prompt': prompt,
            'response': response
        }
        return write_dict

    for file in file_names:
        file = PROJECT_ROOT + file
        read_and_write_template(file, save_file, process_function)


if __name__ == '__main__':
    processed_file_dir = PROJECT_ROOT + '/data/my_data'
    if not exists(processed_file_dir):
        mkdir(processed_file_dir)

    # steps 1-7 are commented out so they are not re-run
    # 1.
    # process_web_text(keep_start=5, response_less_word=15)

    # 2.
    # process_bake_qa(response_less_word=15)

    # 3.
    # process_chinese_medical_datasets(response_less_word=15)

    # 4. the finance Q&A dataset quality is too poor
    # process_finace_dataset(prompt_less_word=10, response_less_word=15)

    # 5.
    # process_zhihu_kol_dataset(prompt_less_word=4, response_less_word=10)

    # 6.
    # process_belle_knowledge_enhanced_dataset(response_less_words=5)

    # convert_wiki_to_simple_zh()

    # 7.
    # process_zh_wiki_data_to_datset(groups_cnt=10000, max_len=512)

    # =================================================================

    # merge
    # merge_dataset_as_single_file(groups_cnt=50000, min_len=3, max_len=512, cut_max_len=True)

    remove_dataset_duplicate_rows(groups_cnt=50000)

    # shuffle
    # shuffle_parquet_dataset(
    #     parquet_file=PROJECT_ROOT + '/data/my_dataset.parquet',
    #     shuffle_file=PROJECT_ROOT + '/data/my_dataset.shuffle.parquet',
    #     seed=23333
    # )

    # split into train, validation and test sets
    # split_train_valid_test_datasets(
    #     source_parquet_file=PROJECT_ROOT + '/data/my_dataset.shuffle.parquet',
    #     max_len=320,
    #     groups_cnt=50000
    # )

    # parquet_to_text()

    # count_my_parquet_data(PROJECT_ROOT + '/data/my_dataset.parquet')

    # dataset_length_cnt()

    # process_belle_knowledge_enhanced_dataset_for_finetune(max_len=320, group_cnt=50000)

    # count_my_parquet_data(PROJECT_ROOT + '/data/')

    parquet_to_json()

    # count_my_json_data()