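"""
Pronoun correction for machine-translated text.

Each line of the input is embedded in as much surrounding context as fits in the
model window, every pronoun in that context is masked, and a RoBERTa fill-mask
pipeline re-predicts it; a pronoun is only rewritten when the model considers the
original choice sufficiently unlikely (see the `alpha` threshold in
fix_pronouns_in_text).
"""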
from tqdm import tqdm
from transformers import FillMaskPipeline, RobertaTokenizerFast


MAX_CTX_LEN = 512

# RoBERTa's byte-level BPE renders a leading space as 'Ġ' at the start of a token string
SPACE_PREFIX = 'Ġ'

# personal and possessive pronoun tokens, each in lowercase/capitalised form and
# with/without the leading-space prefix
PRONOUN_TOKENS = {
    'I', 'ĠI',
    'you', 'You', 'Ġyou', 'ĠYou',
    'he', 'He', 'Ġhe', 'ĠHe',
    'she', 'She', 'Ġshe', 'ĠShe',
    'it', 'It', 'Ġit', 'ĠIt',
    'we', 'We', 'Ġwe', 'ĠWe',
    'they', 'They', 'Ġthey', 'ĠThey',
    'my', 'My', 'Ġmy', 'ĠMy',
    'your', 'Your', 'Ġyour', 'ĠYour',
    'his', 'His', 'Ġhis', 'ĠHis',
    'her', 'Her', 'Ġher', 'ĠHer',
    'its', 'Its', 'Ġits', 'ĠIts',
    'our', 'Our', 'Ġour', 'ĠOur',
    'their', 'Their', 'Ġtheir', 'ĠTheir',
    'mine', 'Mine', 'Ġmine', 'ĠMine',
    'yours', 'Yours', 'Ġyours', 'ĠYours',
    'hers', 'Hers', 'Ġhers', 'ĠHers',
    'ours', 'Ours', 'Ġours', 'ĠOurs',
    'theirs', 'Theirs', 'Ġtheirs', 'ĠTheirs',
}


def count_tokens(tokenizer, text: str) -> int:
    """ return number of tokens in string """
    return len(tokenizer(text)['input_ids'])


def text_to_token_names(tokenizer, text: str) -> list[str]:
    """ return the token strings (e.g. 'ĠI') for a text """
    input_ids = tokenizer(text)["input_ids"]
    return tokenizer.convert_ids_to_tokens(input_ids)


def text_to_token_ids(tokenizer, text: str) -> list[int]:
    """ return the token ids for a text """
    return tokenizer(text)["input_ids"]


def has_at_least_one_pronoun(tokenizer, text: str) -> bool:
    """ check whether the text contains at least one pronoun token """
    token_names = text_to_token_names(tokenizer, text)
    return any(token_name in PRONOUN_TOKENS for token_name in token_names)


def chunk_to_contexts(tokenizer, text: str):
    """ for each line of text, yield (context, focus_line_idx): the line plus as many
    surrounding lines as fit in MAX_CTX_LEN tokens, and the line's index within that chunk """
    lines = text.splitlines()
    
    for i in range(len(lines)):
        
        # add lines before and after for context
        ctx = [lines[i]]
        focus_line_idx = 0
        before_line_idx = i
        after_line_idx = i
        
        # try adding lines as context until we reach MAX_CTX_LEN
        while True:
            
            something_done = False
            
            # try adding a line before
            if before_line_idx - 1 >= 0:
                before_candidate = [lines[before_line_idx - 1]] + ctx
                assert len(before_candidate) == len(ctx) + 1
                if count_tokens(tokenizer, "\n".join(before_candidate)) < MAX_CTX_LEN:
                    ctx = before_candidate
                    focus_line_idx += 1
                    before_line_idx -= 1
                    something_done = True
            
            # try adding a line after
            if after_line_idx + 1 < len(lines):
                after_candidate = ctx + [lines[after_line_idx + 1]]
                if count_tokens(tokenizer, "\n".join(after_candidate)) < MAX_CTX_LEN:
                    ctx = after_candidate
                    after_line_idx += 1
                    something_done = True
            
            # if we can't add any line, we're done
            if not something_done:
                break
        
        # sanity check: re-joining and re-splitting the context preserves the line count
        assert len("".join(f"{x}\n" for x in ctx).splitlines()) == len(ctx)
        
        yield "".join(f"{x}\n" for x in ctx), focus_line_idx


def mask_pronouns(tokenizer: RobertaTokenizerFast, text: str) -> tuple[str, list[str]]:
    """ replace every pronoun in text with the mask token; return the masked text
    and the original pronoun tokens in order of appearance """
    token_names = text_to_token_names(tokenizer, text)
    
    masked_token_names = []
    original_pronouns = []
    
    for token_name in token_names:
        if token_name in PRONOUN_TOKENS:
            masked_token_names.append(tokenizer.mask_token)
            original_pronouns.append(token_name)
        else:
            masked_token_names.append(token_name)
    
    masked_text = tokenizer.decode(tokenizer.convert_tokens_to_ids(masked_token_names), skip_special_tokens=False)
    # remove start and end tokens
    return masked_text[len(tokenizer.bos_token):-len(tokenizer.eos_token)], original_pronouns


def uncase_token(token_name: str) -> str:
    """ strip spaces and the 'Ġ' prefix from a token string and lowercase it """
    token_name = token_name.replace(' ', '')
    token_name = token_name.replace(SPACE_PREFIX, '')
    return token_name.lower()


def uncase_mask_result(mask_result):
    """ sum fill-mask scores per uncased pronoun, so e.g. 'He', ' he' and 'he' all
    contribute to the score of 'he' """
    uncased_token_probs = {uncase_token(k): 0 for k in PRONOUN_TOKENS}
    for guess in mask_result:
        uncased_token_str = uncase_token(guess['token_str'])
        if uncased_token_str not in uncased_token_probs:
            continue
        uncased_token_probs[uncased_token_str] += guess['score']
    return uncased_token_probs


def case_token_like(best_token_uncased: str, original_token: str, best_token_cased_str: str) -> str:
    """
    :param best_token_uncased: the uncased, unspaced token that's the best match
    :param original_token: the original token being replaced; its leading space and
        capitalisation are copied onto the replacement
    :param best_token_cased_str: the cased token str of the best match; used for
        capitalisation when the original token is 'I' and carries no casing information
    :return: the best-match token, spaced and cased like the original token
    """
    space = (SPACE_PREFIX == original_token[0]) or (' ' == original_token[0])
    cap = original_token[1 if space else 0].isupper()
    if best_token_uncased == 'i':
        cap = True
    # if the original token was 'I', we can't use it for cap info
    if original_token in ['I', 'ĠI']:
        cap = best_token_cased_str.strip()[0].isupper()
    if cap:
        best_token_uncased = best_token_uncased[0].upper() + best_token_uncased[1:]
    if space:
        best_token_uncased = ' ' + best_token_uncased
    return best_token_uncased


def fix_pronouns_in_text(
        unmasker: FillMaskPipeline,
        tokenizer,
        text: str,
        alpha: float = 0.05,
        use_tqdm: bool = False,
        tqdm_kwargs=None
) -> str:
    """
    Fixes pronouns in MTL text
    :param unmasker: unmasker pipeline
    :param tokenizer: model tokenizer
    :param text: text to fix
    :param alpha: only replace the existing pronouns with probability less than alpha
    :param use_tqdm: show tqdm progress bar
    :param tqdm_kwargs: any tqdm args
    :return: the fixed text
    """
    if tqdm_kwargs is None:
        tqdm_kwargs = {}
    
    fixed_lines = []
    ctxs = list(chunk_to_contexts(tokenizer, text))
    ctxs_iter = tqdm(ctxs, smoothing=0.0, desc="Fixing pronouns", **tqdm_kwargs) if use_tqdm else ctxs
    
    for ctx, focus_line_idx in ctxs_iter:
        
        ctx_lines = ctx.splitlines()
        focus_line = ctx_lines[focus_line_idx]
        
        # we can skip focusing on lines without a pronoun
        if not has_at_least_one_pronoun(tokenizer, focus_line):
            fixed_lines.append(focus_line)
            continue
        
        # mask all pronouns
        masked_ctx, original_pronouns = mask_pronouns(tokenizer, ctx)
        
        # unmask pronouns
        mask_results = unmasker(masked_ctx)
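        # with a single mask the pipeline returns a flat list of dicts; wrap it so the
        # multi-mask format below (one list of guesses per mask) handles both cases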
        if isinstance(mask_results[0], dict):
            mask_results = [mask_results]
        
        unmasked_ctx = masked_ctx
        
        for i, mask_result in enumerate(mask_results):
            original_pronoun = original_pronouns[i]
            uncased_original = uncase_token(original_pronoun)
            uncased_result = uncase_mask_result(mask_result)
            
            # if what was there doesn't make any sense, replace it
            if uncased_result[uncased_original] < alpha:
                best_uncased_pronoun = max(uncased_result, key=uncased_result.get)
                # TODO: ensure correct type, possessive, adj, subject
                best_cased_pronoun = case_token_like(best_uncased_pronoun, original_pronoun, mask_result[0]['token_str'])
                unmasked_ctx = unmasked_ctx.replace(tokenizer.mask_token, best_cased_pronoun, 1)
            else:
                unmasked_ctx = unmasked_ctx.replace(tokenizer.mask_token, original_pronoun.replace(SPACE_PREFIX, ' '), 1)
        
        fixed_lines.append(unmasked_ctx.splitlines()[focus_line_idx])
    
    return "\n".join(fixed_lines)