Gosse Minnema
Add sociofillmore code, load dataset via private dataset repo
b11ac48
raw
history blame
7.49 kB
"""
Adapted from comm2multilabel.py from the Bert-for-FrameNet project (https://gitlab.com/gosseminnema/bert-for-framenet)
"""
import dataclasses
import json
import os
import glob
import sys
from collections import defaultdict
from typing import List, Optional
import nltk
from concrete import Communication
from concrete.util import read_communication_from_file, lun, get_tokens
@dataclasses.dataclass
class FrameAnnotation:
    """Base container for the token-aligned columns of one annotated sentence.

    Each field defaults to a fresh, empty list per instance.
    """

    # Token strings, in sentence order.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # One tag string per token (presumably part-of-speech; convert_file in
    # this file fills it with "_" placeholders).
    pos: List[str] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Sentence annotation with a list of frame labels and an optional LU per token.

    ``frame_list[i]`` holds the label strings attached to ``tokens[i]`` (empty
    list when there are none); ``lu_list[i]`` is a string or ``None``.
    The text serialization uses four whitespace-separated columns per token:
    token, POS, ``|``-joined labels, LU, with ``_`` standing for "empty".
    """

    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one column-format line (without line terminator) per token."""
        for i, tok in enumerate(self.tokens):
            yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}"

    @staticmethod
    def from_txt(sentence_lines):
        """Parse the lines of one sentence (as produced by ``to_txt``).

        Lines starting with a space and lines that are empty or contain only
        whitespace are skipped, so a stray sentence separator does not raise.

        Args:
            sentence_lines: iterable of strings, one token per line.

        Returns:
            A ``MultiLabelAnnotation`` built from the parsed columns.

        Raises:
            IndexError: if a non-blank line has fewer than four columns.
        """
        tokens = []
        pos = []
        frame_list = []
        lu_list = []

        for line in sentence_lines:
            # ignore any spaces
            if line.startswith(" "):
                continue

            columns = line.split()
            if not columns:
                # blank / whitespace-only line (e.g. a sentence separator)
                continue

            tokens.append(columns[0])
            pos.append(columns[1])

            # "_" encodes an empty label list / a missing LU
            frame_list.append([] if columns[2] == "_" else columns[2].split("|"))
            lu_list.append(None if columns[3] == "_" else columns[3])

        return MultiLabelAnnotation(tokens, pos, frame_list, lu_list)

    def get_label_set(self):
        """Return the set of distinct labels used on any token of the sentence."""
        return {label for tok_labels in self.frame_list for label in tok_labels}
def convert_file(file, language="english", confidence_filter=0.0):
    """Read a Concrete communication and yield one annotation per sentence.

    Args:
        file: path handed to ``read_communication_from_file``.
        language: language name forwarded to ``prepare_ml_lists``.
        confidence_filter: threshold forwarded to ``map_tokens_to_annotations``.

    Yields:
        ``MultiLabelAnnotation`` objects whose POS column is filled with ``"_"``
        and whose LU column is filled with ``None``.
    """
    print("Reading input file...")
    comm = read_communication_from_file(file)

    print("Mapping sentences to situations...")
    situations_by_tokenization = map_sent_to_situation(comm)
    print("# sentences with situations:", len(situations_by_tokenization))

    # Flatten sections -> sentences; lun() wraps both list fields as in the
    # original nested loops.
    sentences = (
        sentence
        for section in lun(comm.sectionList)
        for sentence in lun(section.sentenceList)
    )
    for sentence in sentences:
        tokenization = sentence.tokenization
        tokens = get_tokens(tokenization)
        situations = situations_by_tokenization[tokenization.uuid.uuidString]
        tok_to_annos = map_tokens_to_annotations(comm, situations, confidence_filter)
        frame_list, tok_list = prepare_ml_lists(language, tok_to_annos, tokens)

        n_tokens = len(tok_list)
        yield MultiLabelAnnotation(tok_list, ["_"] * n_tokens, frame_list, [None] * n_tokens)
def prepare_ml_lists(language, tok_to_annos, tokens):
    """Re-tokenize with NLTK and spread each token's labels over its pieces.

    Every input token is split with ``nltk.word_tokenize``; each resulting
    piece receives its own copy of the labels registered for the token's
    index. Two clean-up passes then run piece by piece, in order:

    1. For pieces matching the punctuation string, a label is dropped when the
       piece is the last one, or when the next piece does not carry the label
       (``T:`` labels compared as-is, others compared with the first character
       replaced by ``I``).
    2. A ``B:`` label that also occurs verbatim on the previous piece is
       rewritten to start with ``I``.

    Args:
        language: language name passed to ``nltk.word_tokenize``.
        tok_to_annos: mapping from token index to a list of label strings.
        tokens: iterable of objects exposing a ``text`` attribute.

    Returns:
        ``(frame_list, tok_list)``: per-piece label lists and the piece strings.
    """
    tok_list = []
    frame_list = []
    for position, token in enumerate(tokens):
        # split tokens that include punctuation
        pieces = nltk.word_tokenize(token.text, language=language)
        tok_list += pieces

        labels = tok_to_annos.get(position, [])
        # each piece gets an independent copy so the in-place edits below
        # never leak to a sibling piece
        frame_list += [list(labels) for _ in pieces]

    # remove annotations from final punctuation & solve BIO weird stuff
    punctuation = ",.:;\"'`«»"
    final_idx = len(tok_list) - 1
    for idx, (piece, labels) in enumerate(zip(tok_list, frame_list)):
        # NOTE(review): this is a substring test, so a multi-character piece
        # such as ",." also matches -- confirm that is intended.
        if piece in punctuation:
            if idx == final_idx:
                # sentence-final punctuation loses all of its labels
                stale = list(labels)
            else:
                following = frame_list[idx + 1]
                stale = [
                    label
                    for label in labels
                    if (label if label.startswith("T:") else "I" + label[1:]) not in following
                ]
            for label in stale:
                labels.remove(label)

        if idx > 0:
            previous = frame_list[idx - 1]
            # NOTE(review): `previous` already reflects rewrites made in this
            # loop, so a B: label is not matched once the neighbour's copy has
            # become I: (relevant when a token splits into 3+ pieces) --
            # confirm whether that is intended.
            for slot, label in enumerate(labels):
                # check if we had exactly the same label the token before
                if label.startswith("B:") and label in previous:
                    labels[slot] = "I" + label[1:]

    return frame_list, tok_list
def map_tokens_to_annotations(comm: Communication, situations: List[str], confidence_filter: float):
    """Build per-token label strings for the given situation mentions.

    For every situation UUID (looked up in ``comm.situationMentionForUUID``)
    whose confidence is not below ``confidence_filter`` and whose kind is not
    ``@@VIRTUAL_ROOT@@``:

    * each target token gets ``T:<frame>@<idx>@@<confidence>``;
    * each token of every argument that passes the same confidence test gets
      ``B:`` (first token) or ``I:`` (later tokens) followed by
      ``<frame>:<role>@<idx>@@<confidence>``.

    ``<idx>`` is the zero-padded position of the situation in ``situations``.

    Returns:
        A ``defaultdict(list)`` mapping token index to its label strings.
    """
    annos_by_token = defaultdict(list)

    for sit_idx, sit_uuid in enumerate(situations):
        situation = comm.situationMentionForUUID[sit_uuid]
        if situation.confidence < confidence_filter:
            continue

        frame_type = situation.situationKind
        # read before the virtual-root check, as in the original statement order
        target_indices = situation.tokens.tokenIndexList
        if frame_type == "@@VIRTUAL_ROOT@@":
            continue

        for tok_id in target_indices:
            annos_by_token[tok_id].append(f"T:{frame_type}@{sit_idx:02}@@{situation.confidence}")

        for arg in situation.argumentList:
            if arg.confidence < confidence_filter:
                continue

            fe_type = arg.role
            for offset, tok_id in enumerate(arg.entityMention.tokens.tokenIndexList):
                # the first token of the span opens it, the rest continue it
                bio = "B" if offset == 0 else "I"
                annos_by_token[tok_id].append(f"{bio}:{frame_type}:{fe_type}@{sit_idx:02}@@{arg.confidence}")

    return annos_by_token
def map_sent_to_situation(comm):
    """Group situation-mention UUIDs by the tokenization they refer to.

    Args:
        comm: object exposing ``situationMentionSetList``; each set exposes
            ``mentionList``, and each mention exposes
            ``tokens.tokenizationId.uuidString`` and ``uuid.uuidString``.

    Returns:
        A ``defaultdict(list)`` mapping tokenization UUID string to the UUID
        strings of the mentions located in that tokenization, in encounter
        order. Unknown keys yield an empty list.
    """
    tok_uuid_to_situation = defaultdict(list)
    # Either list may be unset (None) on a communication without situation
    # annotations; treat that as "no mentions" instead of raising TypeError,
    # mirroring the lun(...) wrapping convert_file applies to its list fields.
    for mention_set in comm.situationMentionSetList or ():
        for mention in mention_set.mentionList or ():
            tok_uuid_to_situation[mention.tokens.tokenizationId.uuidString].append(mention.uuid.uuidString)
    return tok_uuid_to_situation
def main():
    """Command-line entry point.

    Positional arguments (``sys.argv[1:5]``): input file, language, output
    directory, confidence filter (parsed as float).

    By default writes ``<output_directory>/lome_<input basename>.json`` and
    ``.txt``. When ``split_by_migration_files`` is enabled, each annotation is
    instead written to its own ``.txt`` file named after the corresponding
    ``*.orig.txt`` file (paired in numeric file-name order).
    """
    file_in = sys.argv[1]
    language = sys.argv[2]
    output_directory = sys.argv[3]
    confidence_filter = float(sys.argv[4])
    split_by_migration_files = False

    file_in_base = os.path.basename(file_in)
    file_out = f"{output_directory}/lome_{file_in_base}"
    multi_label_annos = list(convert_file(file_in, language=language, confidence_filter=confidence_filter))

    if split_by_migration_files:
        files = glob.glob("output/migration/split_data/split_dev10_sep_txt_files/*.orig.txt")
        files.sort(key=lambda path: int(_strip_orig_suffix(path)))
        for anno, file in zip(multi_label_annos, files):
            basename = _strip_orig_suffix(file)
            _write_annotations_txt(f"{output_directory}/{basename}.txt", [anno])
    else:
        print(file_out)
        # the JSON payload is only needed on this branch
        multi_label_json = [dataclasses.asdict(anno) for anno in multi_label_annos]
        with open(f"{file_out}.json", "w", encoding="utf-8") as f_json:
            json.dump(multi_label_json, f_json, indent=4)
        _write_annotations_txt(f"{file_out}.txt", multi_label_annos)


def _strip_orig_suffix(path):
    """Return the base name of *path* with a trailing ``.orig.txt`` removed."""
    # str.rstrip(".orig.txt") would strip any trailing run of those
    # characters, not the suffix, so slice the suffix off explicitly.
    suffix = ".orig.txt"
    name = os.path.basename(path)
    return name[:-len(suffix)] if name.endswith(suffix) else name


def _write_annotations_txt(path, annos):
    """Write *annos* to *path* in column format, with a blank line after each."""
    with open(path, "w", encoding="utf-8") as f_txt:
        for anno in annos:
            # "\n" rather than os.linesep: text mode already translates
            # newlines, so os.linesep would produce "\r\r\n" on Windows.
            for line in anno.to_txt():
                f_txt.write(line + "\n")
            f_txt.write("\n")
if __name__ == "__main__":
    # Run the converter only when executed as a script, not on import.
    main()