""" | |
Adapted from comm2multilabel.py from the Bert-for-FrameNet project (https://gitlab.com/gosseminnema/bert-for-framenet) | |
""" | |
import dataclasses
import glob
import json
import os
import sys
from collections import defaultdict
from typing import List, Optional

import nltk
from concrete import Communication
from concrete.util import read_communication_from_file, lun, get_tokens


@dataclasses.dataclass
class FrameAnnotation:
    tokens: List[str] = dataclasses.field(default_factory=list)
    pos: List[str] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        # one line per token: token, POS, "|"-joined frame labels, lexical unit
        for i, tok in enumerate(self.tokens):
            yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}"
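
    # Illustrative to_txt() output line (one per token; "_" marks an empty
    # column, multiple frame labels are joined with "|"; the label values
    # below are invented for the example):
    #   fled _ T:Fleeing@00@@0.97|B:Escape:Theme@01@@0.88 _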

    @staticmethod
    def from_txt(sentence_lines):
        tokens = []
        pos = []
        frame_list = []
        lu_list = []
        for line in sentence_lines:
            # skip any lines that start with whitespace
            if line.startswith(" "):
                continue
            columns = line.split()
            tokens.append(columns[0])
            pos.append(columns[1])
            # read the frame label column, handling empty label lists
            if columns[2] == "_":
                frame_list.append([])
            else:
                frame_list.append(columns[2].split("|"))
            # read the lexical unit column, handling nulls
            if columns[3] == "_":
                lu_list.append(None)
            else:
                lu_list.append(columns[3])
        return MultiLabelAnnotation(tokens, pos, frame_list, lu_list)

    def get_label_set(self):
        label_set = set()
        for tok_labels in self.frame_list:
            for label in tok_labels:
                label_set.add(label)
        return label_set


def convert_file(file, language="english", confidence_filter=0.0):
    print("Reading input file...")
    comm = read_communication_from_file(file)
    print("Mapping sentences to situations...")
    tok_uuid_to_situation = map_sent_to_situation(comm)
    print("# sentences with situations:", len(tok_uuid_to_situation))
    for section in lun(comm.sectionList):
        for sentence in lun(section.sentenceList):
            tokens = get_tokens(sentence.tokenization)
            situations = tok_uuid_to_situation[sentence.tokenization.uuid.uuidString]
            tok_to_annos = map_tokens_to_annotations(comm, situations, confidence_filter)
            frame_list, tok_list = prepare_ml_lists(language, tok_to_annos, tokens)
            # POS tags and lexical units are not available here, so fill both
            # columns with placeholders
            ml_anno = MultiLabelAnnotation(tok_list, ["_" for _ in tok_list], frame_list,
                                           [None for _ in tok_list])
            yield ml_anno
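

# Minimal usage sketch for convert_file(), assuming "doc.comm" is a Concrete
# Communication on disk and the NLTK "punkt" tokenizer models are installed
# (nltk.download("punkt")):
#   for anno in convert_file("doc.comm", language="english", confidence_filter=0.5):
#       print("\n".join(anno.to_txt()))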


def prepare_ml_lists(language, tok_to_annos, tokens):
    tok_list = []
    frame_list = []
    for tok_idx, tok in enumerate(tokens):
        # split tokens that include punctuation
        split_tok = nltk.word_tokenize(tok.text, language=language)
        tok_list.extend(split_tok)
        tok_anno = []
        for anno in tok_to_annos.get(tok_idx, []):
            tok_anno.append(anno)
        # every sub-token inherits a copy of the original token's annotations
        frame_list.extend([list(tok_anno) for _ in split_tok])

    # remove annotations from final punctuation & repair BIO inconsistencies
    for idx, (tok, frame_annos) in enumerate(zip(tok_list, frame_list)):
        if tok in ",.:;\"'`«»":
            to_delete = []
            for fa in frame_annos:
                if fa.startswith("T:"):
                    compare_fa = fa
                else:
                    # treat "B:..." like "I:..." when checking for continuation
                    compare_fa = "I" + fa[1:]
                if idx == len(tok_list) - 1:
                    to_delete.append(fa)
                elif compare_fa not in frame_list[idx + 1]:
                    to_delete.append(fa)
            for fa in to_delete:
                frame_annos.remove(fa)
        for fa_idx, fa in enumerate(frame_annos):
            if fa.startswith("B:"):
                # check if exactly the same label occurred on the previous token
                if idx > 0 and fa in frame_list[idx - 1]:
                    frame_annos[fa_idx] = "I" + fa[1:]
    return frame_list, tok_list
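

# Example of the repair pass above (labels invented for illustration): if
# nltk.word_tokenize() splits "home." into ["home", "."], both sub-tokens
# initially carry "B:Motion:Goal@00@@0.9". The copy on "." is dropped when the
# span ends there, and a "B:" label that already occurred on the previous
# token is demoted to the matching "I:" label.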


def map_tokens_to_annotations(comm: Communication, situations: List[str], confidence_filter: float):
    tok_to_annos = defaultdict(list)
    for sit_idx, sit_uuid in enumerate(situations):
        situation = comm.situationMentionForUUID[sit_uuid]
        if situation.confidence < confidence_filter:
            continue
        frame_type = situation.situationKind
        tgt_tokens = situation.tokens.tokenIndexList
        if frame_type == "@@VIRTUAL_ROOT@@":
            continue
        # frame target tokens
        for tok_id in tgt_tokens:
            tok_to_annos[tok_id].append(f"T:{frame_type}@{sit_idx:02}@@{situation.confidence}")
        # frame element (argument) tokens, BIO-encoded
        for arg in situation.argumentList:
            if arg.confidence < confidence_filter:
                continue
            fe_type = arg.role
            fe_tokens = arg.entityMention.tokens.tokenIndexList
            for tok_n, tok_id in enumerate(fe_tokens):
                bio = "B" if tok_n == 0 else "I"
                tok_to_annos[tok_id].append(f"{bio}:{frame_type}:{fe_type}@{sit_idx:02}@@{arg.confidence}")
    return tok_to_annos
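

# Shape of the label strings built above (values invented for illustration):
#   frame targets:  "T:<frame>@<situation idx>@@<confidence>"
#                   e.g. "T:Motion@00@@0.87"
#   frame elements: "<B|I>:<frame>:<role>@<situation idx>@@<confidence>"
#                   e.g. "B:Motion:Theme@00@@0.92"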


def map_sent_to_situation(comm):
    # map each tokenization UUID to the situation mentions annotated on it
    tok_uuid_to_situation = defaultdict(list)
    for situation_set in comm.situationMentionSetList:
        for mention in situation_set.mentionList:
            tok_uuid_to_situation[mention.tokens.tokenizationId.uuidString].append(mention.uuid.uuidString)
    return tok_uuid_to_situation


def main():
    file_in = sys.argv[1]
    language = sys.argv[2]
    output_directory = sys.argv[3]
    confidence_filter = float(sys.argv[4])
    # hard-coded switch: write one .txt per source file instead of a single output
    split_by_migration_files = False

    file_in_base = os.path.basename(file_in)
    file_out = f"{output_directory}/lome_{file_in_base}"
    multi_label_annos = list(convert_file(file_in, language=language, confidence_filter=confidence_filter))
    multi_label_json = [dataclasses.asdict(anno) for anno in multi_label_annos]
    if split_by_migration_files:
        files = glob.glob("output/migration/split_data/split_dev10_sep_txt_files/*.orig.txt")
        # str.rstrip() strips a character set, not a suffix, so use
        # str.removesuffix() (Python >= 3.9) to recover the numeric basename
        files.sort(key=lambda f: int(f.split("/")[-1].removesuffix(".orig.txt")))
        for anno, file in zip(multi_label_annos, files):
            basename = file.split("/")[-1].removesuffix(".orig.txt")
            spl_file_out = f"{output_directory}/{basename}"
            with open(f"{spl_file_out}.txt", "w", encoding="utf-8") as f_txt:
                for line in anno.to_txt():
                    f_txt.write(line + os.linesep)
                f_txt.write(os.linesep)
    else:
        print(file_out)
        with open(f"{file_out}.json", "w", encoding="utf-8") as f_json:
            json.dump(multi_label_json, f_json, indent=4)
        with open(f"{file_out}.txt", "w", encoding="utf-8") as f_txt:
            for anno in multi_label_annos:
                for line in anno.to_txt():
                    f_txt.write(line + os.linesep)
                f_txt.write(os.linesep)


if __name__ == '__main__':
    main()