# marriage_law_retrieval/preprocessors.py
# Uploaded by luciusssss ("Upload 22 files", commit a48216a).
import os
import json
import pickle as pkl
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
class BasicPreprocessor(object):
    """Load a labelled dataset from disk, shuffle it, and split it into datasets.

    Raw records are read from ``os.path.join(args.data_dir, args.data_file)``
    (a pickle or JSON file) and shuffled once at construction time.
    :meth:`process` then hands slices of them to ``data_generator``.
    """

    def __init__(self, data_generator, tokenizer, args):
        """Load and shuffle the raw data and fit the label binarizer.

        Args:
            data_generator: Callable invoked as
                ``data_generator(records, tokenizer, mlb, mode, args)`` to build
                each dataset split (``mode`` is ``'train'`` or ``'test'``).
            tokenizer: Passed through unchanged to ``data_generator``.
            args: Config object. This class reads ``data_dir``, ``data_file``,
                ``labels`` and ``test_only`` from it.

        Raises:
            ValueError: If the data file name ends in neither "pkl" nor "json".
        """
        self.data_generator = data_generator
        self.tokenizer = tokenizer
        self.args = args
        file_path = os.path.join(args.data_dir, args.data_file)
        self.raw_data = self._load_raw_data(file_path)
        self.shuffle()
        # Fit once on the full label list; the same binarizer is passed to
        # every split, so all splits share one label-column order.
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([args.labels])

    @staticmethod
    def _load_raw_data(file_path):
        """Return the records stored in a pickle ("pkl") or JSON ("json") file.

        Raises:
            ValueError: If ``file_path`` ends in neither "pkl" nor "json".
        """
        if file_path.endswith("pkl"):
            # NOTE(review): unpickling can execute arbitrary code; only load
            # data files that come from a trusted source.
            with open(file_path, "rb") as f:
                return pkl.load(f)
        if file_path.endswith("json"):
            with open(file_path, "r", encoding="utf-8") as f:
                return json.load(f)
        raise ValueError(
            "Unsupported data file %r: expected a .pkl or .json file" % (file_path,)
        )

    def shuffle(self):
        """Randomly permute ``self.raw_data`` using NumPy's global RNG.

        ``self.raw_data`` is replaced by a NumPy array built with ``np.array``
        from the original records (for a list of dicts, a 1-D object array).
        """
        idx = np.arange(len(self.raw_data))
        np.random.shuffle(idx)
        self.raw_data = np.array(self.raw_data)[idx]

    def process(self, train_ratio=0.9):
        """Build the train and test datasets from the shuffled records.

        Args:
            train_ratio: Fraction of the records used for training. The rest
                form the test set, and no validation set is created. This is
                ignored when ``args.test_only`` is set.

        Returns:
            A ``(train_data, test_data)`` tuple, as produced by ``data_generator``.

        Raises:
            ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        """
        args = self.args
        raw_data = self.raw_data
        if args.test_only:
            # Evaluation-only run: every record goes to the test set. The
            # "train" set is a single record built in 'test' mode (presumably
            # a placeholder so callers still get two datasets).
            train_data = self.data_generator(raw_data[:1], self.tokenizer, self.mlb, 'test', args)
            test_data = self.data_generator(raw_data, self.tokenizer, self.mlb, 'test', args)
            return train_data, test_data
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError("train_ratio must be in [0, 1], got %r" % (train_ratio,))
        # Train on the first train_ratio of the records and test on the rest
        # (90% / 10% by default). No validation split is made.
        split = int(len(raw_data) * train_ratio)
        train_data = self.data_generator(raw_data[:split], self.tokenizer, self.mlb, 'train', args)
        test_data = self.data_generator(raw_data[split:], self.tokenizer, self.mlb, 'test', args)
        return train_data, test_data