jwyang
first commit
4121bec
raw
history blame
No virus
8.58 kB
import logging
import random
class InputExample(object):
    """One raw training/test sample for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None,
                 lm_labels=None, img_id=None, is_img_match=None,
                 img_label=None):
        """Store the fields of one example; nothing is validated or copied.

        Args:
            guid: Unique id for the example.
            tokens_a: Tokens of the first sequence. The feature conversion in
                this file measures, slices and masks it as a list of tokens.
                For single sequence tasks only this sequence is needed.
            tokens_b: (Optional) Tokens of the second sequence. Only needed
                for sequence pair tasks.
            is_next: (Optional) Next-sentence label.
            lm_labels: (Optional) Masked words for the language model.
            img_id: (Optional) Stored as given; presumably the identifier of
                the paired image -- confirm against the caller.
            is_img_match: (Optional) Stored as given; presumably flags whether
                the text and the image belong together -- confirm against the
                caller.
            img_label: (Optional) Stored as given; semantics are defined by
                the caller.
        """
        # Identity and text sequences.
        self.guid, self.tokens_a, self.tokens_b = guid, tokens_a, tokens_b
        # Text-side labels: next-sentence flag and masked-LM targets.
        self.is_next, self.lm_labels = is_next, lm_labels
        # Image-side fields.
        self.img_id, self.is_img_match, self.img_label = (
            img_id, is_img_match, img_label)
class InputFeatures(object):
    """Model-ready features of one sample; a plain attribute container."""

    def __init__(self, input_ids, input_mask, segment_ids, is_next,
                 lm_label_ids, img_feat_len, is_img_match):
        """Keep every argument as an attribute of the same name.

        The values are stored as given, without copying or validation. In
        this file they are produced by ``convert_example_to_features``.
        """
        # Per-position model inputs.
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids, input_mask, segment_ids)
        # Training targets on the text side.
        self.is_next, self.lm_label_ids = is_next, lm_label_ids
        # Image-related values.
        self.img_feat_len, self.is_img_match = img_feat_len, is_img_match
def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.

    Every token is selected with probability 0.15. A selected token is
    replaced by "[MASK]" 80% of the time, replaced by a random vocabulary
    entry 10% of the time, and kept unchanged the remaining 10% of the time.

    NOTE: ``tokens`` is modified in place and the very same list is returned.

    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (only its
        ``vocab`` mapping from token to id is used; it must contain "[UNK]")
    :return: (list of str, list of int), masked tokens and related labels for
        LM prediction. The label is the vocabulary id of the ORIGINAL token
        at selected positions and -1 everywhere else.
    """
    vocab = tokenizer.vocab
    # Materialising the vocabulary is O(|vocab|), so it is done at most once
    # per call, and only if a random replacement is actually needed.
    # random.choice indexes the sequence by position, so choosing from the
    # keys yields the same token as choosing an item and taking its key.
    vocab_tokens = None
    output_label = []
    for i, token in enumerate(tokens):
        prob = random.random()
        if prob >= 0.15:
            # Token not selected for masking (ignored by the loss function later).
            output_label.append(-1)
            continue

        # Rescale to [0, 1) to choose between the three treatments.
        prob /= 0.15
        if prob < 0.8:
            # 80%: change token to the mask token.
            tokens[i] = "[MASK]"
        elif prob < 0.9:
            # 10%: change token to a random vocabulary token.
            if vocab_tokens is None:
                vocab_tokens = list(vocab)
            tokens[i] = random.choice(vocab_tokens)
        # Remaining 10%: keep the current token.

        # The label is always the original token (it is predicted later).
        try:
            output_label.append(vocab[token])
        except KeyError:
            # For unknown words (should not occur with BPE vocab)
            output_label.append(vocab["[UNK]"])
            logging.warning(
                "Cannot find token '{}' in vocab. Using [UNK] insetad".format(
                    token))
    return tokens, output_label
def convert_example_to_features(args, example, max_seq_length, tokenizer,
                                img_feat_len):
    """
    Convert a raw sample (one sentence or a pair of sentences, already
    tokenized) into a proper training sample with IDs, LM labels, input_mask,
    CLS and SEP tokens etc.

    The text part is padded to ``max_seq_length``. ``lm_label_ids`` is then
    extended by ``args.max_img_seq_length`` entries of -1, and when
    ``args.max_img_seq_length > 0`` ``input_mask`` is extended by the same
    number of image positions (1 for actual image features, 0 for padding).

    NOTE: the token lists held by ``example`` may be modified in place: pair
    truncation pops from them and ``random_word`` overwrites masked entries.

    :param args: parameter settings (only ``args.max_img_seq_length`` is read)
    :param example: InputExample, containing sentence input as token lists and is_next label
    :param max_seq_length: int, maximum length of the text sequence,
        special tokens included.
    :param tokenizer: Tokenizer (must provide ``convert_tokens_to_ids`` and
        whatever ``random_word`` requires)
    :param img_feat_len: lens of actual img features
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens_a = example.tokens_a
    tokens_b = None
    if example.tokens_b:
        tokens_b = example.tokens_b
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is at most the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    elif len(tokens_a) > max_seq_length - 2:
        # Account for [CLS] and [SEP] with "- 2"
        tokens_a = tokens_a[:(max_seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.

    # First segment: [CLS] tokens_a [SEP], all with segment id 0. The special
    # tokens never get an LM label (-1).
    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    tokens = ["[CLS]"]
    tokens.extend(tokens_a)
    tokens.append("[SEP]")
    segment_ids = [0] * len(tokens)
    lm_label_ids = [-1] + t1_label + [-1]

    # Optional second segment: tokens_b [SEP], all with segment id 1.
    if tokens_b:
        tokens_b, t2_label = random_word(tokens_b, tokenizer)
        tokens.extend(tokens_b)
        tokens.append("[SEP]")
        segment_ids.extend([1] * (len(tokens_b) + 1))
        lm_label_ids.extend(t2_label + [-1])

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length (a negative amount pads nothing and
    # is caught by the assertions below).
    pad_len = max_seq_length - len(input_ids)
    input_ids.extend([0] * pad_len)
    input_mask.extend([0] * pad_len)
    segment_ids.extend([0] * pad_len)
    lm_label_ids.extend([-1] * pad_len)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    # image features
    if args.max_img_seq_length > 0:
        # The image part of the mask always spans exactly
        # `max_img_seq_length` positions. The number of attended positions is
        # clamped so that an over-long `img_feat_len` cannot make the mask
        # longer than `lm_label_ids`.
        attended = min(img_feat_len, args.max_img_seq_length)
        input_mask = (input_mask + [1] * attended
                      + [0] * (args.max_img_seq_length - attended))

    # Image positions never carry an LM label.
    lm_label_ids = lm_label_ids + [-1] * args.max_img_seq_length

    if example.guid < 1:
        # Log the example(s) with guid below 1 for manual inspection.
        logging.info("*** Example ***")
        logging.info("guid: %s", example.guid)
        logging.info("tokens: %s", " ".join(str(x) for x in tokens))
        logging.info("input_ids: %s", " ".join(str(x) for x in input_ids))
        logging.info("input_mask: %s", " ".join(str(x) for x in input_mask))
        logging.info("segment_ids: %s", " ".join(str(x) for x in segment_ids))
        logging.info("LM label: %s ", lm_label_ids)
        logging.info("Is next sentence label: %s ", example.is_next)

    features = InputFeatures(input_ids=input_ids,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next,
                             img_feat_len=img_feat_len,
                             is_img_match=example.is_img_match)
    return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()