import argparse
import logging
import pickle

import numpy as np
from tqdm import tqdm
from transformers import BertTokenizerFast

from chatbot.config import config


def create_logger(log_path):
    """Write log output to both a log file and the console."""
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes log records to the log file
    file_handler = logging.FileHandler(filename=log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Handler that prints log records to the console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger


def preprocess():
    """
    Tokenize the raw corpus. Each dialogue is converted into the token ids of
    "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]" and the full list of
    dialogues is pickled to the save path.
    """
    # Command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False,
                        help='path to the vocabulary file')
    parser.add_argument('--log_path', default='data/preprocess.log', type=str, required=False,
                        help='path to the preprocessing log file')
    parser.add_argument('--train_path', default='data/train.txt', type=str, required=False,
                        help='path to the raw training corpus')
    parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False,
                        help='path to save the tokenized training dataset')
    args = parser.parse_args()

    # Initialize the logger
    logger = create_logger(args.log_path)

    # Initialize the tokenizer and register the project-specific mask tokens
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]",
                                  pad_token="[PAD]", cls_token="[CLS]")
    special_tokens = list(config["mask_token"].keys())
    tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
    sep_id = tokenizer.sep_token_id
    cls_id = tokenizer.cls_token_id
    logger.info("preprocessing data, data path: {}, save path: {}".format(
        args.train_path, args.save_path))

    # Read the training corpus
    with open(args.train_path, 'rb') as f:
        data = f.read().decode("utf-8")

    # Line endings differ between Linux and Windows; detect them once.
    # Dialogues are separated by a blank line, utterances by a single newline.
    newline = "\r\n" if "\r\n" in data else "\n"
    train_data = data.split(newline * 2)
    logger.info("there are {} dialogues in the dataset".format(len(train_data)))

    # Tokenize every dialogue
    dialogue_len = []   # token length of each dialogue, used for the summary statistics
    dialogue_list = []  # token ids of each dialogue
    for dialogue in tqdm(train_data):
        utterances = dialogue.split(newline)
        input_ids = [cls_id]  # every dialogue starts with [CLS]
        for utterance in utterances:
            input_ids += tokenizer.encode(utterance, add_special_tokens=False)
            input_ids.append(sep_id)  # [SEP] marks the end of each utterance
        dialogue_len.append(len(input_ids))
        dialogue_list.append(input_ids)

    len_mean = np.mean(dialogue_len)
    len_median = np.median(dialogue_len)
    len_max = np.max(dialogue_len)

    # Save the tokenized dialogues
    with open(args.save_path, "wb") as f:
        pickle.dump(dialogue_list, f)
    logger.info("finished preprocessing data, the result is stored in {}".format(args.save_path))
    logger.info("mean of dialogue len: {}, median of dialogue len: {}, max len: {}".format(
        len_mean, len_median, len_max))


if __name__ == '__main__':
    preprocess()