"""

Adapted from comm2multilabel.py from the Bert-for-FrameNet project (https://gitlab.com/gosseminnema/bert-for-framenet)

"""

import dataclasses
import json
import os
import glob
import sys
from collections import defaultdict
from typing import List, Optional

import nltk  # word_tokenize below needs the Punkt models; run nltk.download("punkt") once if missing
from concrete import Communication
from concrete.util import read_communication_from_file, lun, get_tokens


@dataclasses.dataclass
class FrameAnnotation:
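    """A tokenized sentence with one POS tag per token."""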
    tokens: List[str] = dataclasses.field(default_factory=list)
    pos: List[str] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
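    """Adds per-token frame labels (possibly several per token) and lexical unit labels."""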
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
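        """Yield one line per token in the 4-column format
        "<token> <pos> <frame1|frame2|...> <lu>", using "_" for an empty
        frame list or a missing lexical unit."""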
        for i, tok in enumerate(self.tokens):
            yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}"

    @staticmethod
    def from_txt(sentence_lines):
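        """Parse the column format produced by to_txt() back into a MultiLabelAnnotation."""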

        tokens = []
        pos = []
        frame_list = []
        lu_list = []
        for line in sentence_lines:
            # skip lines that start with a space (no token in the first column)
            if line.startswith(" "):
                continue

            columns = line.split()
            tokens.append(columns[0])
            pos.append(columns[1])

            # read the frame label list; "_" marks an empty list
            if columns[2] == "_":
                frame_list.append([])
            else:
                frame_list.append(columns[2].split("|"))

            # read the lexical unit; "_" marks a missing one
            if columns[3] == "_":
                lu_list.append(None)
            else:
                lu_list.append(columns[3])
        return MultiLabelAnnotation(tokens, pos, frame_list, lu_list)

    def get_label_set(self):
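        """Collect the set of all frame labels used anywhere in this sentence."""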
        label_set = set()
        for tok_labels in self.frame_list:
            for label in tok_labels:
                label_set.add(label)
        return label_set


def convert_file(file, language="english", confidence_filter=0.0):
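    """Read a Concrete Communication file and yield one MultiLabelAnnotation per
    sentence. POS tags are not recovered from the Communication, so every token
    gets the placeholder tag "_", and no lexical units are assigned."""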
    print("Reading input file...")
    comm = read_communication_from_file(file)

    print("Mapping sentences to situations...")
    tok_uuid_to_situation = map_sent_to_situation(comm)

    print("# sentences with situations:", len(tok_uuid_to_situation))

    for section in lun(comm.sectionList):
        for sentence in lun(section.sentenceList):
            tokens = get_tokens(sentence.tokenization)
            situations = tok_uuid_to_situation[sentence.tokenization.uuid.uuidString]
            tok_to_annos = map_tokens_to_annotations(comm, situations, confidence_filter)

            frame_list, tok_list = prepare_ml_lists(language, tok_to_annos, tokens)

            ml_anno = MultiLabelAnnotation(tok_list, ["_" for _ in tok_list], frame_list,
                                           [None for _ in tok_list])
            yield ml_anno


def prepare_ml_lists(language, tok_to_annos, tokens):
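    """Re-tokenize each token with nltk.word_tokenize (so punctuation glued to a
    token becomes its own token), copy the original token's annotations onto all
    of its sub-tokens, then clean up the resulting label sequences (see below)."""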
    tok_list = []
    frame_list = []
    for tok_idx, tok in enumerate(tokens):
        # split tokens that include punctuation
        split_tok = nltk.word_tokenize(tok.text, language=language)
        tok_list.extend(split_tok)
        tok_anno = []
        for anno in tok_to_annos.get(tok_idx, []):
            tok_anno.append(anno)
        frame_list.extend([list(tok_anno) for _ in split_tok])

    # drop annotations that end on a punctuation token, and repair B-/I- tags
    # broken by the re-tokenization above
    for idx, (tok, frame_annos) in enumerate(zip(tok_list, frame_list)):
        if tok in ",.:;\"'`«»":
            to_delete = []
            for fa in frame_annos:
                if fa.startswith("T:"):
                    compare_fa = fa
                else:
                    compare_fa = "I" + fa[1:]

                if idx == len(tok_list) - 1:
                    to_delete.append(fa)
                elif compare_fa not in frame_list[idx + 1]:
                    to_delete.append(fa)

            for fa in to_delete:
                frame_annos.remove(fa)

        for fa_idx, fa in enumerate(frame_annos):
            if fa.startswith("B:"):
                # if the previous token carries the identical "B:" label, this token
                # continues that span, so demote it to "I:"
                if idx > 0 and fa in frame_list[idx - 1]:
                    frame_annos[fa_idx] = "I" + fa[1:]

    return frame_list, tok_list


def map_tokens_to_annotations(comm: Communication, situations: List[str], confidence_filter: float):
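    """Map token indices to label strings for every situation mention in
    `situations` that passes the confidence filter. Target (trigger) tokens get
    labels of the form "T:<frame>@<sit_idx>@@<confidence>"; frame element tokens
    get BIO-style labels "B:<frame>:<role>@<sit_idx>@@<confidence>" / "I:..."."""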
    tok_to_annos = defaultdict(list)
    for sit_idx, sit_uuid in enumerate(situations):
        situation = comm.situationMentionForUUID[sit_uuid]
        if situation.confidence < confidence_filter:
            continue

        frame_type = situation.situationKind
        tgt_tokens = situation.tokens.tokenIndexList

        # skip placeholder situations attached to the virtual root
        if frame_type == "@@VIRTUAL_ROOT@@":
            continue

        for tok_id in tgt_tokens:
            tok_to_annos[tok_id].append(f"T:{frame_type}@{sit_idx:02}@@{situation.confidence}")
        for arg in situation.argumentList:
            if arg.confidence < confidence_filter:
                continue

            fe_type = arg.role
            fe_tokens = arg.entityMention.tokens.tokenIndexList
            for tok_n, tok_id in enumerate(fe_tokens):
                if tok_n == 0:
                    bio = "B"
                else:
                    bio = "I"
                tok_to_annos[tok_id].append(f"{bio}:{frame_type}:{fe_type}@{sit_idx:02}@@{arg.confidence}")
    return tok_to_annos


def map_sent_to_situation(comm):
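    """Map each tokenization UUID to the UUIDs of the situation mentions anchored in it."""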
    tok_uuid_to_situation = defaultdict(list)
    for mention_set in comm.situationMentionSetList:
        for mention in mention_set.mentionList:
            tok_uuid_to_situation[mention.tokens.tokenizationId.uuidString].append(mention.uuid.uuidString)
    return tok_uuid_to_situation


def main():
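    """CLI entry point: <input .comm file> <language> <output directory> <confidence filter>."""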
    file_in = sys.argv[1]
    language = sys.argv[2]
    output_directory = sys.argv[3]
    confidence_filter = float(sys.argv[4])
    split_by_migration_files = False  # hardcoded switch: write one .txt per original split file

    file_in_base = os.path.basename(file_in)
    file_out = f"{output_directory}/lome_{file_in_base}"
    multi_label_annos = list(convert_file(file_in, language=language, confidence_filter=confidence_filter))
    multi_label_json = [dataclasses.asdict(anno) for anno in multi_label_annos]

    if split_by_migration_files:
        files = glob.glob("output/migration/split_data/split_dev10_sep_txt_files/*.orig.txt")
        # str.rstrip() strips a *character set*, not a suffix, so use
        # removesuffix() (Python 3.9+) to take the part before ".orig.txt" safely
        files.sort(key=lambda f: int(os.path.basename(f).removesuffix(".orig.txt")))

        for anno, file in zip(multi_label_annos, files):
            basename = os.path.basename(file).removesuffix(".orig.txt")
            spl_file_out = f"{output_directory}/{basename}"
            with open(f"{spl_file_out}.txt", "w", encoding="utf-8") as f_txt:
                for line in anno.to_txt():
                    f_txt.write(line + os.linesep)
                f_txt.write(os.linesep)

    else:
        print(file_out)
        with open(f"{file_out}.json", "w", encoding="utf-8") as f_json:
            json.dump(multi_label_json, f_json, indent=4)

        with open(f"{file_out}.txt", "w", encoding="utf-8") as f_txt:
            for anno in multi_label_annos:
                for line in anno.to_txt():
                    f_txt.write(line + os.linesep)
                f_txt.write(os.linesep)


if __name__ == '__main__':
    main()