|
from argparse import ArgumentParser |
|
import hashlib |
|
import os |
|
|
|
from sftp.data_reader import BetterDatasetReader, ConcreteDatasetReader |
|
from tools.ontology_mapping.force_map import ontology_map, read_framenet |
|
|
|
|
|
def read_ace_better(reader, data_path):
    """Collect the raw sentence and span annotations of every instance.

    Args:
        reader: Any object exposing ``read(data_path)`` that yields instances
            whose ``fields['raw_inputs'].metadata`` is a mapping containing
            the keys ``'sentence'`` and ``'spans'``.
        data_path: Location handed unchanged to ``reader.read``.

    Returns:
        A list with one ``(sentence, spans)`` tuple per instance, in the
        order the reader produced them. Empty if the reader yields nothing.

    Raises:
        KeyError: If an instance lacks ``'raw_inputs'`` or its metadata lacks
            ``'sentence'`` / ``'spans'``.
    """
    # Only these two metadata entries are kept; anything else the reader
    # attaches to the instance is dropped.
    wanted_keys = ('sentence', 'spans')
    return [
        tuple(instance.fields['raw_inputs'].metadata[key] for key in wanted_keys)
        for instance in reader.read(data_path)
    ]
|
|
|
|
|
def _file_md5(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is consumed in ``chunk_size``-byte pieces inside a ``with``
    block, so the handle is always closed and the file is never held in
    memory as a whole.
    """
    digest = hashlib.md5()
    with open(path, 'rb') as stream:
        # iter(callable, sentinel) stops at the first empty read, i.e. EOF.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def run(model_path, src_data_path, tgt_data_path, device, dst_path):
    """Map a FrameNet source ontology onto an ACE or BETTER target data set.

    Steps, in order:
      1. Hash the model archive and print the digest.
      2. Choose a dataset reader from the target path name.
      3. Read both data sets and hand everything to ``ontology_map``.

    Args:
        model_path: Either a ``.tar.gz`` archive, or a directory that
            contains ``model.tar.gz``.
        src_data_path: Path passed to ``read_framenet``.
        tgt_data_path: Path of the target data. Its lower-cased form must
            contain ``'better'`` or ``'ace'``; ``'better'`` is checked first.
        device: Forwarded untouched to ``ontology_map``.
        dst_path: Forwarded untouched to ``ontology_map``.

    Raises:
        NotImplementedError: If ``tgt_data_path`` names neither a BETTER nor
            an ACE data set.
        OSError: If the model archive cannot be opened.
    """
    if model_path.endswith('.tar.gz'):
        archive_path = model_path
    else:
        archive_path = os.path.join(model_path, 'model.tar.gz')
    model_md5 = _file_md5(archive_path)
    print('model md5: ', model_md5)

    # The reader is picked purely from the (case-insensitive) path name.
    tgt_name = tgt_data_path.lower()
    if 'better' in tgt_name:
        reader = BetterDatasetReader(eval_type='basic', pretrained_model='roberta-large', ignore_label=False)
    elif 'ace' in tgt_name:
        reader = ConcreteDatasetReader(ignore_unlabeled_sentence=True, pretrained_model='roberta-large')
    else:
        raise NotImplementedError(
            "cannot infer a dataset reader from target path {!r}; "
            "expected it to contain 'better' or 'ace'".format(tgt_data_path)
        )

    # Provenance record passed along to ontology_map with the results.
    meta = {
        'model': {'path': model_path, 'md5': model_md5},
        'src_data_path': src_data_path,
        'tgt_data_path': tgt_data_path,
    }

    src_data = read_framenet(src_data_path)
    tgt_data = read_ace_better(reader, tgt_data_path)
    ontology_map(model_path, src_data, tgt_data, device, dst_path, meta)
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: four required positional paths plus an
    # optional integer device id given with -d (defaults to -1).
    parser = ArgumentParser()
    parser.add_argument('model', metavar='MODEL_PATH')
    parser.add_argument('src', metavar='SRC_DATA_PATH')
    parser.add_argument('tgt', metavar='TGT_DATA_PATH')
    parser.add_argument('dst', metavar='DESTINATION_PATH')
    parser.add_argument('-d', type=int, help='device', default=-1)
    args = parser.parse_args()

    # The CLI takes dst before the device flag; run() expects device first.
    run(
        model_path=args.model,
        src_data_path=args.src,
        tgt_data_path=args.tgt,
        device=args.d,
        dst_path=args.dst,
    )
|
|