from tqdm import tqdm
from transformers import FillMaskPipeline, RobertaTokenizerFast

# Exclusive upper bound on the token count (as measured by count_tokens) of a
# context window passed to the fill-mask model.
MAX_CTX_LEN = 512

# Marker the tokenizer puts on tokens that are preceded by a space (e.g. 'Ġhe').
SPACE_PREFIX = 'Ġ'

# Vocabulary token strings treated as pronouns: every casing, with and without
# the leading-space marker.
PRONOUN_TOKENS = {
    'I', 'ĠI',
    'you', 'You', 'Ġyou', 'ĠYou',
    'he', 'He', 'Ġhe', 'ĠHe',
    'she', 'She', 'Ġshe', 'ĠShe',
    'it', 'It', 'Ġit', 'ĠIt',
    'we', 'We', 'Ġwe', 'ĠWe',
    'they', 'They', 'Ġthey', 'ĠThey',
    'my', 'My', 'Ġmy', 'ĠMy',
    'your', 'Your', 'Ġyour', 'ĠYour',
    'his', 'His', 'Ġhis', 'ĠHis',
    'her', 'Her', 'Ġher', 'ĠHer',
    'its', 'Its', 'Ġits', 'ĠIts',
    'our', 'Our', 'Ġour', 'ĠOur',
    'their', 'Their', 'Ġtheir', 'ĠTheir',
    'mine', 'Mine', 'Ġmine', 'ĠMine',
    'yours', 'Yours', 'Ġyours', 'ĠYours',
    'hers', 'Hers', 'Ġhers', 'ĠHers',
    'ours', 'Ours', 'Ġours', 'ĠOurs',
    'theirs', 'Theirs', 'Ġtheirs', 'ĠTheirs',
}


def count_tokens(tokenizer, text: str) -> int:
    """
    Return the number of token ids the tokenizer produces for *text*
    (including any special tokens it adds).
    """
    return len(tokenizer(text)['input_ids'])


def text_to_token_names(tokenizer, text: str) -> list[str]:
    """
    Return the vocabulary token strings (e.g. 'Ġhe') for *text*, in order,
    including any special tokens the tokenizer adds.
    """
    return tokenizer.convert_ids_to_tokens(tokenizer(text)["input_ids"])


def text_to_token_ids(tokenizer, text: str) -> list[int]:
    """
    Return the token ids the tokenizer produces for *text*.
    """
    return tokenizer(text)["input_ids"]


def has_at_least_one_pronoun(tokenizer, text: str) -> bool:
    """
    Return True if any token of *text* is one of PRONOUN_TOKENS.
    """
    return any(name in PRONOUN_TOKENS for name in text_to_token_names(tokenizer, text))


def chunk_to_contexts(tokenizer, text: str):
    """
    Yield one context window per line of *text*.

    Each window starts as a single line (the focus line) and is grown by one
    line before and one line after per round. An addition is kept only while
    the newline-joined window stays under MAX_CTX_LEN tokens. Growth stops
    when neither side can be extended. A line that alone exceeds the limit is
    still yielded, as a single-line context.

    :param tokenizer: tokenizer used to measure window length
    :param text: text to split; lines come from str.splitlines()
    :return: generator of (context, focus_line_idx) pairs. ``context`` is the
        window's lines, each newline-terminated. ``focus_line_idx`` is the index
        of the focus line within context.splitlines().
    """
    lines = text.splitlines()
    for i, line in enumerate(lines):
        ctx = [line]
        focus_line_idx = 0
        before_line_idx = i
        after_line_idx = i

        # Keep adding neighbouring lines until neither side fits any more.
        while True:
            grew = False

            # Try adding the preceding line.
            if before_line_idx - 1 >= 0:
                before_candidate = [lines[before_line_idx - 1]] + ctx
                if count_tokens(tokenizer, "\n".join(before_candidate)) < MAX_CTX_LEN:
                    ctx = before_candidate
                    # Prepending shifts the focus line one position down.
                    focus_line_idx += 1
                    before_line_idx -= 1
                    grew = True

            # Try adding the following line.
            if after_line_idx + 1 < len(lines):
                after_candidate = ctx + [lines[after_line_idx + 1]]
                if count_tokens(tokenizer, "\n".join(after_candidate)) < MAX_CTX_LEN:
                    ctx = after_candidate
                    after_line_idx += 1
                    grew = True

            if not grew:
                break

        # Elements come from str.splitlines(), so none contains a line break and
        # the joined text splits back into exactly len(ctx) lines.
        yield "".join(f"{ctx_line}\n" for ctx_line in ctx), focus_line_idx


def mask_pronouns(tokenizer: RobertaTokenizerFast, text: str) -> tuple[str, list[str]]:
    """
    Replace every pronoun token in *text* with the tokenizer's mask token.

    :param tokenizer: tokenizer whose vocabulary strings match PRONOUN_TOKENS
    :param text: text to mask
    :return: (masked_text, original_pronouns). ``original_pronouns`` lists the
        replaced token strings (e.g. 'Ġhe') in order of appearance.
    """
    masked_token_names = []
    original_pronouns = []
    for token_name in text_to_token_names(tokenizer, text):
        if token_name in PRONOUN_TOKENS:
            masked_token_names.append(tokenizer.mask_token)
            original_pronouns.append(token_name)
        else:
            masked_token_names.append(token_name)

    # Disable tokenization-space cleanup so the non-pronoun text round-trips
    # unchanged (cleanup would e.g. turn " ," into ",").
    masked_text = tokenizer.decode(
        tokenizer.convert_tokens_to_ids(masked_token_names),
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )
    # Decoding keeps the start/end special tokens the tokenizer added; drop them.
    masked_text = masked_text.removeprefix(tokenizer.bos_token or "")
    masked_text = masked_text.removesuffix(tokenizer.eos_token or "")
    return masked_text, original_pronouns


def uncase_token(token_name: str) -> str:
    """
    Return *token_name* lowercased, with spaces and the SPACE_PREFIX marker removed.
    """
    return token_name.replace(' ', '').replace(SPACE_PREFIX, '').lower()


def uncase_mask_result(mask_result):
    """
    Sum fill-mask prediction scores per uncased pronoun.

    :param mask_result: predictions for one mask; each a dict with 'token_str'
        and 'score'
    :return: dict mapping every uncased pronoun (e.g. 'he') to the summed
        score of the predictions that uncase to it (0.0 if none). Predictions
        that are not pronouns are ignored.
    """
    # Sorted so dict iteration order (and hence max() tie-breaking by callers)
    # does not depend on the set's hash-randomized order.
    uncased_token_probs = {uncase_token(k): 0.0 for k in sorted(PRONOUN_TOKENS)}
    for guess in mask_result:
        uncased_token_str = uncase_token(guess['token_str'])
        if uncased_token_str in uncased_token_probs:
            uncased_token_probs[uncased_token_str] += guess['score']
    return uncased_token_probs


def case_token_like(best_token_uncased: str, original_token: str, best_token_cased_str: str) -> str:
    """
    Re-apply the original token's leading space and capitalization to a replacement pronoun.

    :param best_token_uncased: the uncased, unspaced replacement pronoun (e.g. 'he')
    :param original_token: the token being replaced, as a vocabulary string (e.g. 'ĠShe')
    :param best_token_cased_str: a model prediction string whose casing is used as a
        capitalization hint when original_token is 'I' (whose own casing carries
        no information)
    :return: best_token_uncased, capitalized and/or prefixed with ' ' to match original_token
    """
    space = original_token[:1] in (SPACE_PREFIX, ' ')
    original_word = original_token[1:] if space else original_token

    if original_token in ('I', 'ĠI'):
        # 'I' is capitalized anywhere in a sentence, so borrow the prediction's casing.
        # As in uncase_token, the prediction string may carry the 'Ġ' marker, which
        # must be dropped first ('Ġ'.isupper() is True).
        hint = best_token_cased_str.replace(SPACE_PREFIX, '').strip()
        cap = hint[:1].isupper()
    else:
        cap = original_word[:1].isupper()

    # The pronoun 'I' is always capitalized; this must override the hint above.
    if best_token_uncased == 'i':
        cap = True

    result = best_token_uncased
    if cap:
        result = result[:1].upper() + result[1:]
    if space:
        result = ' ' + result
    return result


def fix_pronouns_in_text(
        unmasker: FillMaskPipeline,
        tokenizer,
        text: str,
        alpha: float = 0.05,
        use_tqdm: bool = False,
        tqdm_kwargs=None
) -> str:
    """
    Fixes pronouns in MTL text.

    Each line is examined inside a context window (see chunk_to_contexts).
    All pronouns in the window are masked and re-predicted. A pronoun is
    replaced only if the model gives it a probability below *alpha* and some
    other pronoun scores strictly higher. Lines without pronouns are passed
    through unchanged.

    :param unmasker: fill-mask pipeline; called on the masked context, returning
        prediction dicts ('token_str', 'score') per mask
    :param tokenizer: model tokenizer
    :param text: text to fix
    :param alpha: only replace existing pronouns whose probability is less than alpha
    :param use_tqdm: show a tqdm progress bar
    :param tqdm_kwargs: extra tqdm arguments; these may override the default
        ``desc`` and ``smoothing``
    :return: the fixed text, lines joined with "\\n"
    :raises ValueError: if the number of masks the pipeline reports differs from
        the number of masked pronouns (e.g. the text already contains the mask token)
    """
    bar_kwargs = {"smoothing": 0.0, "desc": "Fixing pronouns", **(tqdm_kwargs or {})}

    # Materialized so tqdm knows the total.
    ctxs = list(chunk_to_contexts(tokenizer, text))
    ctxs_iter = tqdm(ctxs, **bar_kwargs) if use_tqdm else ctxs

    fixed_lines = []
    for ctx, focus_line_idx in ctxs_iter:
        focus_line = ctx.splitlines()[focus_line_idx]

        # Lines without a pronoun need no model call.
        if not has_at_least_one_pronoun(tokenizer, focus_line):
            fixed_lines.append(focus_line)
            continue

        masked_ctx, original_pronouns = mask_pronouns(tokenizer, ctx)
        mask_results = unmasker(masked_ctx)
        # With a single mask the pipeline returns a flat list of predictions.
        if isinstance(mask_results[0], dict):
            mask_results = [mask_results]
        if len(mask_results) != len(original_pronouns):
            raise ValueError(
                f"expected {len(original_pronouns)} mask predictions, got {len(mask_results)}"
            )

        unmasked_ctx = masked_ctx
        for original_pronoun, mask_result in zip(original_pronouns, mask_results):
            uncased_result = uncase_mask_result(mask_result)
            original_prob = uncased_result[uncase_token(original_pronoun)]
            best_uncased_pronoun = max(uncased_result, key=uncased_result.get)

            # Replace only implausible pronouns, and only with a strictly better one;
            # if no prediction is a pronoun every score is 0 and the original is kept.
            if original_prob < alpha and uncased_result[best_uncased_pronoun] > original_prob:
                # TODO: ensure correct type, possessive, adj, subject
                replacement = case_token_like(
                    best_uncased_pronoun, original_pronoun, mask_result[0]['token_str']
                )
            else:
                replacement = original_pronoun.replace(SPACE_PREFIX, ' ')
            # Masks are filled left to right, matching the order of original_pronouns.
            unmasked_ctx = unmasked_ctx.replace(tokenizer.mask_token, replacement, 1)

        fixed_lines.append(unmasked_ctx.splitlines()[focus_line_idx])

    return "\n".join(fixed_lines)