"""Demo of SpanFinder forced decoding with ``sftp.SpanPredictor``.

Run this module as a script. Optionally pass the path to a trained model
archive as the first command-line argument to override DEFAULT_MODEL_PATH.
"""

# Location of the model archive the demo was written against (machine-specific).
# Override it via ``example(model_path=...)`` or the first CLI argument.
DEFAULT_MODEL_PATH = "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz"


def print_children(sentence, boundary, labels, _):
    """Print a sentence followed by each labelled span found in it.

    Args:
        sentence: list of tokens.
        boundary: iterable of ``(start_idx, end_idx)`` token spans. ``end_idx``
            is inclusive, so ``(1, 1)`` denotes the single token ``sentence[1]``.
        labels: one label per span in ``boundary``. Extra entries on either
            side are ignored, because ``zip`` stops at the shorter input.
        _: ignored. It is accepted so the 3-tuple returned by
            ``SpanPredictor.force_decode`` can be unpacked directly into this
            call: ``print_children(sentence, *result)``.
    """
    print('Sentence:', ' '.join(sentence))
    for (start_idx, end_idx), lbl in zip(boundary, labels):
        # +1 because span end indices are inclusive but slice ends are not.
        print(' '.join(sentence[start_idx:end_idx + 1]), ':', lbl)
    print('=' * 20)


def example(model_path=DEFAULT_MODEL_PATH, cuda_device=-1):
    """Load a SpanFinder model and print three forced-decoding results.

    Args:
        model_path: path to the model archive passed to
            ``SpanPredictor.from_path``.
        cuda_device: device index forwarded to ``from_path``. ``-1`` is the
            value this demo has always used (presumably CPU; confirm in the
            sftp/AllenNLP docs).
    """
    # Imported here rather than at module top so that print_children can be
    # imported and used without the model library installed.
    from sftp import SpanPredictor

    print("Loading predictor...")
    predictor = SpanPredictor.from_path(model_path, cuda_device=cuda_device)

    print("Predicting for sentence..")
    sentence = ['Tom', 'eats', 'an', 'apple', 'and', 'he', 'wakes', 'up', '.']

    # 1. No constraints supplied to force_decode.
    p1 = predictor.force_decode(sentence)
    print_children(sentence, *p1)

    # 2. Parent span fixed to "eats" with label 'Ingestion'.
    p2 = predictor.force_decode(sentence, parent_span=(1, 1), parent_label='Ingestion')
    print_children(sentence, *p2)

    # 3. Same parent, and the child spans are also fixed to "Tom" and "an apple".
    p3 = predictor.force_decode(
        sentence,
        child_spans=[(0, 0), (2, 3)],
        parent_span=(1, 1),
        parent_label='Ingestion',
    )
    print_children(sentence, *p3)


if __name__ == '__main__':
    import sys

    # An optional first CLI argument overrides the model archive path.
    example(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_MODEL_PATH)