# creative-help / rnnlm_model / tokenization_utils.py
# (Hugging Face page header retained as a comment; commit c886682 by roemmele:
#  "Enabled control of generation parameters; created README.md")
"""Tokenization utilities for RNNLM - entity extraction, replacement, and decoding."""
import re
import numpy as np
# RNG for adapt_tok_seq_ents when sampling from sub_ent_probs
_rng = np.random.RandomState(0)
def segment(encoder, seq):
    """Split *seq* into a list of stripped sentence strings.

    `encoder` is a callable (e.g. a spaCy pipeline) whose result exposes a
    `.sents` iterable; each sentence provides `.text` (or legacy `.string`).
    """
    sentences = []
    for sent in encoder(seq).sents:
        # .text on modern spaCy; fall back to legacy .string, then str()
        text = getattr(sent, 'text', getattr(sent, 'string', str(sent)))
        sentences.append(text.strip())
    return sentences
def tokenize(encoder, seq, lowercase=True, recognize_ents=False,
             lemmatize=False, include_tags=(), include_pos=(), prepend_start=False):
    """Tokenize *seq* with `encoder` and return a list of token strings.

    Args:
        encoder: callable returning a Doc-like object (iterable of Token-like
            items with `.text`; `.ents`/`.ent_type_` needed if recognize_ents,
            `.tag_`/`.pos_` if tag filtering, `.lemma_` if lemmatize).
        seq: raw text to tokenize.
        lowercase: lowercase tokens (skipped for 'ENT_'-prefixed placeholders).
        recognize_ents: merge each named entity span into a single token.
        include_tags: keep only tokens with these fine-grained POS tags.
        include_pos: keep only tokens with these coarse-grained POS tags.
        prepend_start: prepend the "<START>" marker token.

    Returns:
        list[str]: the processed tokens (empty strings filtered out).

    Note: the original used mutable list defaults for include_tags/include_pos;
    they are tuples now (behaviorally identical, avoids the shared-default trap).
    """
    seq = encoder(seq)
    if recognize_ents:  # merge named entities into single tokens
        # map entity start index -> entity Span (skip whitespace-only spans)
        ent_start_idxs = {ent.start: ent for ent in seq.ents
                          if getattr(ent, 'text', getattr(ent, 'string', '')).strip()}
        # keep one Span per entity (at its start index) and drop the entity's
        # remaining word tokens; non-entity words pass through unchanged
        seq = [ent_start_idxs[word_idx] if word_idx in ent_start_idxs else word
               for word_idx, word in enumerate(seq)
               if (not word.ent_type_ or word_idx in ent_start_idxs)]

    def _wtext(w):
        # Token/Span text across spaCy versions (.text vs legacy .string)
        return getattr(w, 'text', getattr(w, 'string', str(w))).strip()

    # Don't apply POS filtering to phrases (tokens containing underscores)
    if include_tags:  # fine-grained POS tags
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.tag_ in include_tags)]
    if include_pos:  # coarse-grained POS tags
        seq = [word for word in seq
               if ("_" in _wtext(word) or word.pos_ in include_pos)]
    if lemmatize:
        # don't lemmatize entity placeholders (already strings prefixed 'ENT_',
        # or Spans whose text carries the placeholder prefix)
        seq = [word.lemma_ if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    elif lowercase:
        # likewise keep entity placeholders in their original case
        seq = [_wtext(word).lower() if not _wtext(word).startswith('ENT_')
               else _wtext(word) for word in seq]
    else:
        seq = [_wtext(word) for word in seq]
    # some words may be empty strings, so filter
    seq = [word for word in seq if word]
    if prepend_start:
        seq.insert(0, u"<START>")
    return seq
def ent_counts_to_probs(ent_counts):
    """Convert entity counts to probabilities for sampling when adapting entities.

    Args:
        ent_counts: dict mapping entity type -> {entity: count}.

    Returns:
        dict mapping entity type -> {entity: count / total}, where total is the
        sum of counts within that type (each inner dict sums to 1.0).

    Fix: the original recomputed sum(counts.values()) for every entity inside
    the comprehension (accidental O(n^2) per type); the total is hoisted.
    """
    probs = {}
    for ent_type, counts in ent_counts.items():
        total = sum(counts.values())  # hoisted: constant for this ent_type
        probs[ent_type] = {ent: count * 1.0 / total
                           for ent, count in counts.items()}
    return probs
def get_ents(encoder, seq, include_ent_types=('PERSON', 'NORP', 'ORG', 'GPE')):
    """Extract named entities from *seq*.

    Returns a pair (ents, ent_counts): `ents` maps each entity text to its
    entity-type label, `ent_counts` maps it to its occurrence count. Only
    labels listed in `include_ent_types` are kept; whitespace-only spans
    (which the recognizer can produce) are skipped.
    """
    ents = {}
    ent_counts = {}
    for span in encoder(seq).ents:
        label = span.label_
        if label not in include_ent_types:
            continue
        text = getattr(span, 'text', getattr(span, 'string', str(span))).strip()
        if not text:
            # whitespace can be detected as an entity; ignore such spans
            continue
        # original built a one-element list and '_'-joined it, which reduces
        # to the label itself
        ents[text] = label
        ent_counts[text] = ent_counts.get(text, 0) + 1
    return ents, ent_counts
def number_ents(encoder, ents, ent_counts):
    '''return dict of all entities in seq mapped to their entity types,
    with numerical suffixes to distinguish entities of the same type'''
    # Sort by (count, entity text, type) then reverse: entities are processed
    # most-frequent first (ties broken by text/type ordering, descending).
    ent_counts = sorted([(count, ent, ents[ent])
                         for ent, count in ent_counts.items()])[::-1]
    ent_type_counts = {}  # base type -> how many distinct numbered ids assigned
    num_ents = {}  # entity text -> numbered type string, e.g. 'PERSON_0'
    for count, ent, ent_type in ent_counts:
        tok_ent = tokenize(encoder, ent, lowercase=False)
        # An already-numbered entity co-refers with this one if it shares the
        # same first OR last word and has the same type (e.g. 'John' ~ 'John Smith').
        # NOTE(review): tokenize() is called twice per candidate here — presumably
        # cheap enough in practice; verify before optimizing.
        coref_ent = [num_ent for num_ent in num_ents
                     if (tokenize(encoder, num_ent, lowercase=False)[0] == tok_ent[0]
                         or tokenize(encoder, num_ent, lowercase=False)[-1] == tok_ent[-1])
                     # treat ents with same first or last word as co-referring
                     and ents[num_ent] == ent_type]
        if coref_ent:
            # reuse the numbered id of the first co-referring entity
            num_ents[ent] = num_ents[coref_ent[0]]
        else:
            # assign a fresh id for this base type (0-based)
            ent_type = ent_type.split("_")
            if ent_type[0] in ent_type_counts:
                ent_type_counts[ent_type[0]] += 1
            else:
                ent_type_counts[ent_type[0]] = 1
            num_ents[ent] = ent_type
            # insert number id after entity type (and before tag, if it exists)
            num_ents[ent].insert(1, str(ent_type_counts[ent_type[0]] - 1))
            num_ents[ent] = "_".join(num_ents[ent])
    return num_ents
def replace_ents_in_seq(encoder, seq):
    '''extract entities from seq and replace them with their entity types'''
    raw_ents, raw_counts = get_ents(encoder, seq)
    numbered = number_ents(encoder, raw_ents, raw_counts)

    def token_text(tok):
        # tok may be a spaCy Token or Span; normalize to stripped text
        return (getattr(tok, 'text', None) or getattr(tok, 'string', None) or str(tok)).strip()

    tokens = tokenize(encoder, seq, lowercase=False, recognize_ents=True)
    replaced = []
    for tok in tokens:
        text = token_text(tok)
        # entities become typed placeholders, e.g. 'ENT_PERSON_0'
        replaced.append('ENT_' + numbered[text] if text in numbered else text)
    return " ".join(replaced)
def decode_num_seqs(encoder, lexicon_lookup, unk_word, seqs, max_new_sents=None, eos_tokens=(),
                    detokenize=False, ents=(), capitalize_ents=False, adapt_ents=False,
                    sub_ent_probs=None, begin_sentence=True):
    """Transform numeric token-ID sequences back into strings.

    Args:
        encoder: segmentation/tokenization callable; only used when
            detokenize is set or sentence filtering is requested.
        lexicon_lookup: indexable mapping of token id -> token string; ids out
            of range or mapping to a falsy entry decode to `unk_word`.
        unk_word: string substituted for unknown/empty ids.
        seqs: one flat id sequence, a batch of them, or tensor-like objects
            (anything with .cpu()/.tolist(), e.g. torch tensors, incl. 2D
            (batch, seq_len) outputs of model.generate).
        max_new_sents: keep only the first N sentences (when eos_tokens empty).
        eos_tokens: cut each decoded string at the first of these tokens.
        detokenize: apply formatting rules instead of plain space-joining.
        ents / capitalize_ents / adapt_ents / sub_ent_probs: per-sequence
            entity dicts used to re-insert/capitalize entity names.
        begin_sentence: capitalize sentence starts when detokenizing.

    Returns:
        list[str]: one decoded string per (row of each) input sequence.

    Fixes vs. original: mutable default arguments ([] -> ()); the `_to_int`
    helper was redefined inside the row loop (hoisted); manual append loop
    replaced with a comprehension. Behavior is otherwise unchanged.
    """
    if not seqs:
        return []
    # a single flat sequence is wrapped so the loop below always sees a batch
    if type(seqs[0]) not in (list, np.ndarray, tuple):
        seqs = [seqs]

    def _to_int(x):
        # recursively convert tensor/numpy scalars to plain Python ints
        if isinstance(x, (list, tuple)):
            return [_to_int(v) for v in x]
        return int(x.item()) if hasattr(x, 'item') else int(x)

    decoded_seqs = []
    for seq_idx, seq in enumerate(seqs):
        # normalize to list(s) of Python ints (handles torch tensors)
        if hasattr(seq, 'cpu'):
            seq = seq.cpu()
        if hasattr(seq, 'tolist'):
            seq = seq.tolist()
        elif seq and hasattr(seq[0], 'tolist'):
            # list(tensor) gives row tensors - convert each to a list
            seq = [row.tolist() for row in seq]
        else:
            seq = list(seq)
        # if 2D (batch, seq_len), decode each row; else a single sequence
        rows = seq if (seq and isinstance(seq[0], list)) else [seq]
        for row_idx, row in enumerate(rows):
            flat_row = _to_int(row) if isinstance(row, (list, tuple)) else [_to_int(row)]
            if isinstance(flat_row[0], list):
                flat_row = [v for sub in flat_row
                            for v in (sub if isinstance(sub, list) else [sub])]
            # ids out of range or mapping to falsy entries decode to unk_word
            seq = [lexicon_lookup[i]
                   if (0 <= i < len(lexicon_lookup) and lexicon_lookup[i])
                   else unk_word
                   for i in flat_row]
            if adapt_ents:  # replace ENT_* placeholders with concrete entities
                ent_idx = min(seq_idx + row_idx, len(ents) - 1) if ents else 0
                seq_ents = ents[ent_idx] if ents else {}
                seq = adapt_tok_seq_ents(
                    seq, ents=seq_ents, sub_ent_probs=sub_ent_probs or {})
            if detokenize:  # apply formatting rules to the token list
                if ents and capitalize_ents:
                    ent_idx = min(seq_idx + row_idx,
                                  len(ents) - 1) if ents else 0
                    seq = detokenize_tok_seq(
                        encoder, seq, ents=ents[ent_idx], begin_sentence=begin_sentence)
                else:
                    seq = detokenize_tok_seq(
                        encoder, seq, ents=[], begin_sentence=begin_sentence)
            else:
                # otherwise just join tokens with whitespace between each
                seq = " ".join(seq)
            if eos_tokens:  # cut at the first end-of-sentence token
                seq = filter_gen_seq(encoder, seq, eos_tokens=eos_tokens)
            elif max_new_sents:  # or keep only the first N sentences
                seq = filter_gen_seq(encoder, seq, n_sents=max_new_sents)
            decoded_seqs.append(seq)
    return decoded_seqs
def adapt_tok_seq_ents(seq, ents=None, sub_ent_probs=None):
    """Replace 'ENT_<TYPE>_<N>' placeholder tokens in *seq* with entity names.

    Args:
        seq: list of token strings, some possibly 'ENT_'-prefixed placeholders.
        ents: dict mapping entity name -> numbered type (e.g. 'John' -> 'PERSON_0').
        sub_ent_probs: dict mapping base type -> {entity: probability}, used to
            sample a substitute when no entity of a matching type is available.

    Returns:
        list[str]: seq with each placeholder replaced by (in priority order)
        an exact-type match, a fuzzy base-type match, a sampled substitute,
        any remaining entity, or the placeholder itself as a last resort.

    Fixes vs. original: mutable default arguments ({} -> None); the final
    substitution pass replaced ANY token whose '_'-suffix matched a placeholder
    type (e.g. 'NOT_PERSON_0'), not only 'ENT_'-prefixed tokens — it now checks
    the prefix. Throwaway dict comprehensions used purely for iteration were
    replaced with a list helper.
    """
    ents = dict(ents) if ents else {}
    sub_ent_probs = sub_ent_probs or {}
    # reverse ents so that numbered types map to names
    ents = {ent_type: ent for ent, ent_type in ents.items()}

    def _strip_prefix(token):
        # 'ENT_PERSON_0' -> 'PERSON_0'
        return "_".join(token.split("_")[1:])

    # placeholder type -> chosen entity name (None until assigned)
    adapted_seq_ents = {_strip_prefix(token): None
                        for token in seq if token.startswith('ENT_')}
    if not adapted_seq_ents:
        return seq

    def _unfilled():
        # placeholder types still awaiting an entity, in insertion order
        return [t for t, adapted in adapted_seq_ents.items() if not adapted]

    # 1) exact numbered-type match (consumes the matched entity)
    for seq_ent_type in _unfilled():
        if seq_ent_type in ents:
            adapted_seq_ents[seq_ent_type] = ents.pop(seq_ent_type)

    # 2) fuzzy match: placeholder base type is a substring of an entity's base type
    if ents:
        for seq_ent_type in _unfilled():
            for ent_type in list(ents):
                if seq_ent_type.split("_")[0] in ent_type.split("_")[0]:
                    adapted_seq_ents[seq_ent_type] = ents.pop(ent_type)
                    break

    # 3) sample a substitute entity of the same base type
    for seq_ent_type in _unfilled():
        base_type = seq_ent_type.split("_")[0]
        if base_type in sub_ent_probs:
            sub_ents, sub_probs = zip(*sub_ent_probs[base_type].items())
            # NOTE(review): assumes sub_probs sums to ~1.0 (np.random.choice
            # raises otherwise) — upstream ent_counts_to_probs guarantees this
            rand_ent_idx = _rng.choice(len(sub_ents), p=np.array(sub_probs))
            adapted_seq_ents[seq_ent_type] = sub_ents[rand_ent_idx]

    # 4) last resort: ANY available entity (any type), else keep the placeholder
    all_entities = list(ents.values())
    for type_ents in sub_ent_probs.values():
        all_entities.extend(type_ents.keys())
    for seq_ent_type in _unfilled():
        if all_entities:
            adapted_seq_ents[seq_ent_type] = _rng.choice(all_entities)
        else:
            adapted_seq_ents[seq_ent_type] = "ENT_" + seq_ent_type

    # substitute only genuine placeholders (prefix check is the bug fix)
    return [adapted_seq_ents[_strip_prefix(token)]
            if token.startswith('ENT_') and _strip_prefix(token) in adapted_seq_ents
            else token
            for token in seq]
def detokenize_tok_seq(encoder, seq, ents=[], begin_sentence=True):
    '''use simple rules for transforming list of tokens back into string
    ents is optional list of words (named entities) that should be capitalized'''
    seq = [sent.split() for sent
           in segment(encoder, " ".join(seq))]  # split sequence into sentences
    detok_seq = []
    for sent_idx, sent in enumerate(seq):
        assert (type(sent) in (list, tuple))
        if ents:
            token_idx = 0
            # capitalize all tokens that appear in cap_ents
            while token_idx < len(sent):
                for ent in ents:
                    ent = ent.split()
                    # if the next len(ent) tokens lowercase-match the entity,
                    # restore the entity's original casing in place
                    if sent[token_idx:token_idx + len(ent)] == [token.lower() for token in ent]:
                        # import pdb;pdb.set_trace()
                        sent[token_idx:token_idx + len(ent)] = list(ent)
                        token_idx += len(ent) - 1  # skip past the matched span
                        break
                token_idx += 1
        detok_sent = " ".join(sent)
        # NOTE(review): "\'" is the same string as "'", so this is a no-op
        detok_sent = re.sub("\'", "'", detok_sent)
        # capitalize first-person "I" pronoun (only when followed by a space)
        detok_sent = re.sub(r"(^| )i ", r"\1I ", detok_sent)
        # reattach contraction suffixes to the preceding word
        detok_sent = re.sub(r" n'\s*t ", "n't ", detok_sent)
        detok_sent = re.sub(r" '\s*d ", "'d ", detok_sent)
        detok_sent = re.sub(r" '\s*s ", "'s ", detok_sent)
        detok_sent = re.sub(r" '\s*ve ", "'ve ", detok_sent)
        detok_sent = re.sub(r" '\s*ll ", "'ll ", detok_sent)
        detok_sent = re.sub(r" '\s*m ", "'m ", detok_sent)
        detok_sent = re.sub(r" '\s*re ", "'re ", detok_sent)
        # remove the space before closing punctuation / after opening '$'
        detok_sent = re.sub(" \.", ".", detok_sent)
        detok_sent = re.sub(" \!", "!", detok_sent)
        detok_sent = re.sub(" \?", "?", detok_sent)
        detok_sent = re.sub(" ,", ",", detok_sent)
        detok_sent = re.sub(" \- ", "-", detok_sent)
        detok_sent = re.sub(" :", ":", detok_sent)
        detok_sent = re.sub(" ;", ";", detok_sent)
        detok_sent = re.sub("\$ ", "$", detok_sent)
        detok_sent = re.sub("\' \'", "\'\'", detok_sent)
        detok_sent = re.sub("\` \`", "\`\`", detok_sent)
        # replace repeated single quotes with double quotation mark.
        detok_sent = re.sub("\'\'", "\"", detok_sent)
        detok_sent = re.sub("\`\`", "\"", detok_sent)
        # filter repetitive characters (runs of 2+ quote marks collapse to '" ')
        detok_sent = re.sub(r'(["\']\s*){2,}', '" ', detok_sent)
        # map each opening puncutation mark to closing mark
        # NOTE(review): the "'" key is written twice — the duplicate is dead
        punc_pairs = {"\'": "\'", "\'": "\'",
                      "`": "\'", "\"": "\"", "(": ")", "[": "]"}
        open_punc = []  # stack of currently-open quote/bracket characters
        char_idx = 0
        while char_idx < len(detok_sent):  # check for quotes and parenthesis
            char = detok_sent[char_idx]
            # end quote/parenthesis: drop the space before it
            if open_punc and char == punc_pairs[open_punc[-1]]:
                if char_idx > 0 and detok_sent[char_idx - 1] == " ":
                    detok_sent = detok_sent[:char_idx -
                                            1] + detok_sent[char_idx:]
                open_punc.pop()
            elif char in punc_pairs:
                # opening mark is recognized by a following space, which is removed
                if char_idx < len(detok_sent) - 1 and detok_sent[char_idx + 1] == " ":
                    open_punc.append(char)
                    detok_sent = detok_sent[:char_idx +
                                            1] + detok_sent[char_idx + 2:]
            # only advance if the deletion above didn't shift a new char here
            if char_idx < len(detok_sent) and detok_sent[char_idx] == char:
                char_idx += 1
        detok_sent = detok_sent.strip()
        # capitalize first alphabetic character if begin_sentence is True
        # (preceding chars are punctuation, which .upper() leaves unchanged)
        if begin_sentence:
            for char_idx, char in enumerate(detok_sent):
                if char.isalpha():
                    detok_sent = detok_sent[:char_idx +
                                            1].upper() + detok_sent[char_idx + 1:]
                    break
        detok_seq.append(detok_sent)
    detok_seq = " ".join(detok_seq)
    contraction_patterns = ("'s", "'re", "'ve", "'d", "'ll", "'m", "n't")
    punctuation_patterns = (".", "!", "?", ",", "-", ":", ";", ")", "]")
    # Only prepend space if detok_seq doesn't start with these
    # (a continuation fragment glues directly onto previously generated text)
    starts_with_pattern = detok_seq.startswith(
        contraction_patterns) or detok_seq.startswith(punctuation_patterns)
    if not starts_with_pattern and detok_seq:
        detok_seq = " " + detok_seq
    return detok_seq
def filter_gen_seq(encoder, seq, n_sents=1, eos_tokens=()):
    """Trim a generated string to its first sentences.

    Args:
        encoder: tokenization/segmentation callable (spaCy-style).
        seq: the generated text.
        n_sents: number of leading sentences to keep (used when eos_tokens empty).
        eos_tokens: if given, cut the text at the first occurrence of any of
            these tokens instead of using sentence segmentation.

    Returns:
        str: the trimmed text; a leading space in the input is preserved.

    Fix vs. original: mutable default argument (eos_tokens=[] -> ()).
    """
    leading_space = seq.startswith(" ") if seq else False
    if eos_tokens:
        # cut off at the first end-of-sentence token
        doc = encoder(seq)
        for idx, word in enumerate(doc):
            wtext = getattr(word, 'text', getattr(
                word, 'string', str(word))).strip()
            if wtext in eos_tokens:
                # keep everything up to and including the EOS token
                span = doc[:idx + 1]
                seq = getattr(span, 'text', getattr(
                    span, 'string', str(span))).strip()
                break
        else:
            # no EOS token found: keep the whole decoded text
            seq = getattr(doc, 'text', getattr(doc, 'string', str(doc)))
    else:
        # use the segmenter to infer sentence boundaries; if the first n_sents
        # sentences are all empty, widen the window until something non-empty
        sentences = segment(encoder, seq)
        n = n_sents
        seq = ""
        while n <= len(sentences):
            seq = " ".join(sentences[:n]).strip()
            if seq:
                break
            n += 1
        if not seq and sentences:
            seq = " ".join(sentences).strip()
    if leading_space and seq:
        seq = " " + seq.lstrip()
    return seq