Spaces:
Sleeping
Sleeping
import os | |
import argparse | |
from xml.etree import ElementTree | |
import copy | |
from operator import attrgetter | |
import json | |
import logging | |
from sftp import SpanPredictor | |
def predict_kairos(model_archive, source_folder, onto_map):
    """Run span prediction over every XML file found under *source_folder*.

    Args:
        model_archive: Path to the serialized ``SpanPredictor`` model archive.
        source_folder: Directory tree searched recursively for ``*.xml`` files.
        onto_map: Path to a JSON file mapping KAIROS event types to FrameNet
            frames (``{kairos_event: {"framenet": [{"label": ...}, ...]}}``).

    Returns:
        A list of per-segment prediction dicts; each has a ``meta`` entry
        (document attributes plus a ``seg`` sub-dict of segment attributes)
        and a ``prediction`` list restricted to frames present in the
        ontology map, with labels remapped to KAIROS event types.
    """
    xml_files = list()
    for root, _, files in os.walk(source_folder):
        for f in files:
            if f.endswith('.xml'):
                xml_files.append(os.path.join(root, f))
    logging.info(f'{len(xml_files)} files are found:')
    for fn in xml_files:
        logging.info(' - ' + fn)

    logging.info('Loading ontology from ' + onto_map)
    # Invert the ontology: FrameNet frame label -> KAIROS event type.
    k_map = dict()
    # Fix: use a context manager so the ontology file handle is closed
    # (was json.load(open(onto_map)), which leaks the descriptor).
    with open(onto_map) as onto_fp:
        onto = json.load(onto_fp)
    for kairos_event, content in onto.items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                logging.info("Duplicate frame: " + fr['label'])
            k_map[fr['label']] = kairos_event

    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)

    predictions = list()
    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            # Assumes the first child of each document element is the
            # text element holding the segments -- TODO confirm schema.
            text = list(doc)[0]
            for seg in text:
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                tokens = [child for child in seg if child.tag == 'TOKEN']
                # Fix: XML attribute values are strings; sorting them
                # lexicographically misorders offsets (e.g. '10' < '2').
                # Cast to int so tokens come out in true document order.
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = list(map(attrgetter('text'), tokens))
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta
                # Keep only frames covered by the ontology, remapped to
                # their KAIROS event types.
                new_frames = list()
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                        new_frames.append(fr)
                one_pred['prediction'] = new_frames
                predictions.append(one_pred)
    logging.info('Finished Prediction.')
    return predictions
def do_task(input_dir, model_archive, onto_map):
    """
    Entry point invoked by the KAIROS infrastructure code for each
    TASK1 input.

    Thin adapter: forwards its arguments to :func:`predict_kairos` and
    returns its result unchanged.
    """
    return predict_kairos(
        model_archive=model_archive,
        source_folder=input_dir,
        onto_map=onto_map,
    )
def run():
    """Command-line entry point: parse arguments, predict, write JSONL output.

    Expects four positional arguments (model archive, XML source folder,
    ontology JSON, output path) and writes one JSON object per line to the
    destination file.
    """
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str,
                        help='Path to model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str,
                        help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str,
                        help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str,
                        help='Output path. (jsonl file path)')
    args = parser.parse_args()
    logging.basicConfig(level='INFO',
                        format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
    predictions = predict_kairos(model_archive=args.model_archive,
                                 source_folder=args.source_folder,
                                 onto_map=args.onto_map)
    logging.info('Saving to ' + args.destination + ' ...')
    # Fix: os.path.dirname returns '' for a bare filename, and
    # os.makedirs('') raises FileNotFoundError -- only create the
    # directory when there is a directory component.
    dest_dir = os.path.dirname(args.destination)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    # Explicit encoding so output is UTF-8 regardless of platform locale.
    with open(args.destination, 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(map(json.dumps, predictions)))
    logging.info('Done.')
# Script entry guard: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    run()