import os
import json
import pickle as pkl

import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer


class BasicPreprocessor(object):
    """Load a raw dataset from disk, shuffle it, and split it into train/test sets.

    The dataset is read from ``os.path.join(args.data_dir, args.data_file)``
    (a ``.pkl`` or ``.json`` file).  A ``MultiLabelBinarizer`` is fitted on
    ``args.labels`` and handed to ``data_generator`` together with each split.
    """

    def __init__(self, data_generator, tokenizer, args):
        """Load and shuffle the raw data and fit the label binarizer.

        Args:
            data_generator: Callable invoked as
                ``data_generator(records, tokenizer, mlb, mode, args)`` where
                ``mode`` is ``"train"`` or ``"test"``; its return value is
                passed back unchanged by :meth:`process`.
            tokenizer: Passed through to ``data_generator``.
            args: Config object; this class reads ``data_dir``, ``data_file``,
                ``labels`` and ``test_only`` from it.

        Raises:
            ValueError: If the data file is neither a pickle nor a JSON file.
        """
        self.data_generator = data_generator
        self.tokenizer = tokenizer
        self.args = args

        file_path = os.path.join(args.data_dir, args.data_file)
        if file_path.endswith("pkl"):
            # NOTE(review): pickle can execute arbitrary code while loading;
            # only point this at trusted data files.
            with open(file_path, "rb") as f:
                self.raw_data = pkl.load(f)
        elif file_path.endswith("json"):
            # Use a context manager so the file handle is closed promptly.
            with open(file_path, "r", encoding="utf-8") as f:
                self.raw_data = json.load(f)
        else:
            # Without this, raw_data would be unset and shuffle() would fail
            # with an unhelpful AttributeError.
            raise ValueError(
                f"Unsupported data file {file_path!r}: expected a .pkl or .json file"
            )

        self.shuffle()

        # Fit on a single "sample" containing every label so the binarizer
        # learns the full label vocabulary from args.labels.
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([args.labels])

    def shuffle(self):
        """Randomly permute ``self.raw_data``, rebinding it to a NumPy array.

        Uses NumPy's global RNG, so the order is reproducible only if the
        caller seeds ``np.random`` beforehand.
        """
        # NOTE(review): np.array() on ragged nested sequences raises in recent
        # NumPy versions; this assumes records are dicts or uniform sequences.
        idx = np.arange(len(self.raw_data))
        np.random.shuffle(idx)
        self.raw_data = np.array(self.raw_data)[idx]

    def process(self, train_ratio=0.9):
        """Split the shuffled data and wrap each split with ``data_generator``.

        Args:
            train_ratio: Fraction of records (counted from the start of the
                shuffled data) used for training; the remainder is the test
                set.  No validation set is produced.  Ignored when
                ``args.test_only`` is set.

        Returns:
            A ``(train_data, test_data)`` tuple of ``data_generator`` results.

        Raises:
            ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        """
        args = self.args
        data_generator = self.data_generator
        raw_data = self.raw_data
        tokenizer = self.tokenizer
        mlb = self.mlb

        if args.test_only:
            # Presumably a placeholder: the "train" split is a single record
            # built in test mode, while the full dataset is used for testing.
            train_data = data_generator(raw_data[:1], tokenizer, mlb, "test", args)
            test_data = data_generator(raw_data, tokenizer, mlb, "test", args)
            return train_data, test_data

        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

        # By default, use 90% of the data for training and 10% for testing;
        # no validation set is used.
        split = int(len(raw_data) * train_ratio)
        train_data = data_generator(raw_data[:split], tokenizer, mlb, "train", args)
        test_data = data_generator(raw_data[split:], tokenizer, mlb, "test", args)
        return train_data, test_data