|
import os |
|
import json |
|
import pickle as pkl |
|
import numpy as np |
|
from sklearn.preprocessing import MultiLabelBinarizer |
|
|
|
class BasicPreprocessor(object):
    """Load a labelled dataset from disk and split it into train/test sets.

    The dataset is read from ``os.path.join(args.data_dir, args.data_file)``;
    the file name must end in ``pkl`` (pickle) or ``json``. A
    ``MultiLabelBinarizer`` fitted on ``args.labels`` is handed, together
    with the tokenizer and ``args``, to ``data_generator`` for every split.
    """

    def __init__(self, data_generator, tokenizer, args):
        """Load the raw records and fit the label binarizer.

        Args:
            data_generator: callable that :meth:`process` invokes as
                ``data_generator(records, tokenizer, mlb, mode, args)`` with
                ``mode`` being ``'train'`` or ``'test'``.
            tokenizer: passed through unchanged to ``data_generator``.
            args: namespace providing at least ``data_dir``, ``data_file``
                and ``labels`` (and ``test_only``, read by :meth:`process`).

        Raises:
            ValueError: if the data file name ends in neither ``pkl`` nor
                ``json`` (previously this left ``raw_data`` unset and failed
                later with an AttributeError).
        """
        self.data_generator = data_generator
        self.tokenizer = tokenizer
        self.args = args

        file_path = os.path.join(args.data_dir, args.data_file)
        if file_path.endswith("pkl"):
            # NOTE(review): unpickling can execute arbitrary code; only load
            # data files from a trusted source.
            with open(file_path, "rb") as f:
                self.raw_data = pkl.load(f)
            # NOTE(review): unlike the JSON branch, pickled data is not
            # shuffled; process() splits by position, so confirm the pickle
            # is already in the intended order.
        elif file_path.endswith("json"):
            # `with` guarantees the file handle is closed after parsing.
            with open(file_path, "r", encoding="utf-8") as f:
                self.raw_data = json.load(f)
            self.shuffle()
        else:
            raise ValueError(
                f"Unsupported data file {file_path!r}: expected a name ending "
                f"in 'pkl' or 'json'"
            )

        # args.labels is wrapped in a list so the binarizer sees a single
        # "sample" containing every label, i.e. it learns the full label set.
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([args.labels])

    def shuffle(self):
        """Randomly reorder ``self.raw_data`` using NumPy's global RNG.

        ``self.raw_data`` is rebound to a NumPy array holding the same
        records in a random order (an object array when records are dicts).
        """
        # permutation(n) draws exactly like arange(n) + np.random.shuffle.
        order = np.random.permutation(len(self.raw_data))
        self.raw_data = np.array(self.raw_data)[order]

    def process(self, train_ratio=0.9):
        """Build the train and test datasets from ``self.raw_data``.

        Args:
            train_ratio: fraction of records, taken from the front of
                ``self.raw_data``, used for training; the remainder forms the
                test split. Defaults to 0.9 (the previously hard-coded value).

        Returns:
            ``(train_data, test_data)``, each being whatever
            ``self.data_generator`` returns. When ``args.test_only`` is set,
            all records form the test split and ``train_data`` is built from
            only the first record in ``'test'`` mode (presumably a placeholder
            so callers still receive a pair).

        Raises:
            ValueError: if ``train_ratio`` is outside ``[0, 1]``.
        """
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

        args = self.args
        make = self.data_generator
        records = self.raw_data
        tokenizer = self.tokenizer
        mlb = self.mlb

        if args.test_only:
            train_data = make(records[:1], tokenizer, mlb, 'test', args)
            test_data = make(records, tokenizer, mlb, 'test', args)
            return train_data, test_data

        # Compute the split point once so both slices agree exactly.
        split = int(len(records) * train_ratio)
        train_data = make(records[:split], tokenizer, mlb, 'train', args)
        test_data = make(records[split:], tokenizer, mlb, 'test', args)
        return train_data, test_data
|
|
|
|
|
|
|
|