"""Text-centric preprocessing ops. |
|
|
|
All preprocessing ops should return a data processing functors. A data |
|
is represented as a dictionary of (TF) tensors. The functors output a modified |
|
dictionary. |
|
|
|
A commonly used key for the tokenized output is "labels". |
|
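Example (a rough sketch, not a prescribed recipe): ops registered here are
referenced by name in a preprocessing string and composed by the pp builder
(see big_vision.pp.builder); the exact op string depends on dataset and
config. Something like:

  from big_vision.pp import builder
  pp_fn = builder.get_preprocess_fn(
      "coco_captions|tokenize(max_len=16, eos='yes', inkey='captions')")
  features = pp_fn(data)  # `data` is a dict of TF tensors; adds "labels".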
""" |
|
import functools
import importlib
import os

from absl import logging
from big_vision.datasets.imagenet import class_names as imagenet_class_names
from big_vision.pp import ops_general
from big_vision.pp import tokenizer as bv_tok
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import tensorflow as tf
from tensorflow.io import gfile

import sentencepiece

SPProcessor = sentencepiece.SentencePieceProcessor

# Import the sentencepiece model proto with the pure-python protobuf
# implementation selected; the env var only needs to be set for the
# duration of the import.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import sentencepiece.sentencepiece_model_pb2
del os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]

SPModelProto = sentencepiece.sentencepiece_model_pb2.ModelProto


KNOWN_TOKENIZERS = {
    "mc4": "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
    "cc_all": "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model",
    "c4_en": "gs://t5-data/vocabs/cc_en.32000/sentencepiece.model",
    "t5": "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model",
    "mt5": "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
}


def create_tokenizer(model="c4_en", add_eos=True, add_bos=False):
  """Creates a tokenizer which can be used in tfds."""
  logging.info("Creating tokenizer: %s", model)
  with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f:
    model = f.read()

  # Deferred import: tensorflow_text is only needed when a tokenizer is
  # actually created.
  import tensorflow_text
  return tensorflow_text.SentencepieceTokenizer(
      model=model, add_eos=add_eos, add_bos=add_bos
  )


def tokenize(input_text, tokenizer, max_len, *, pad_value, force_eos,
             multi_text=False):
  """Tokenizes string(s), truncating to `max_len` and padding with `pad_value`."""

  def pad(tokens):
    # Truncate or pad a single 1-D token sequence to exactly `max_len`.
    if force_eos:
      tokens = tf.cond(
          tf.shape(tokens)[0] >= max_len,
          # Truncate, but keep the last token (the "sticky" eos).
          lambda: tf.concat(
              [tokens[:max_len - 1], tokens[-1:]], axis=0),
          # Pad up to max_len.
          lambda: tf.pad(
              tokens, [(0, max_len - tf.shape(tokens)[0])],
              constant_values=pad_value),
      )
    else:
      tokens = tokens[:max_len]
      tokens = tf.pad(
          tokens, [(0, max_len - tf.shape(tokens)[0])],
          constant_values=pad_value)
    tokens.set_shape([max_len])
    return tokens

  tokens = tokenizer.tokenize(input_text)

  if multi_text:
    # `tokens` is a RaggedTensor; tf.map_fn needs a dense tensor, so pad the
    # ragged dimension with `pad_value` first.
    tokens = tokens.to_tensor(pad_value)
    tokens = tf.reshape(tokens, [-1, tf.shape(tokens)[-1]])
    tokens = tf.map_fn(pad, tokens)
    # Restore the input's batch shape, with max_len appended.
    final_shape = tf.concat([tf.shape(input_text), [max_len]], axis=0)
    return tf.reshape(tokens, final_shape)
  else:
    return pad(tokens)
|
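# Illustrative sketch (token ids are hypothetical): with a tokenizer created
# via `create_tokenizer(add_eos=True)` where id("a") == 2 and id("</s>") == 1,
#   tokenize(tf.constant("a a"), tok, max_len=4, pad_value=0, force_eos=False)
# yields roughly [2, 2, 1, 0]; with `multi_text=True` and a vector of N input
# strings, the result has shape [N, max_len].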

@Registry.register("preprocess_ops.tokenize")
@utils.InKeyOutKey(indefault=None, outdefault="labels")
def get_pp_tokenize(
    max_len,
    eos,
    model="c4_en",
    lower=True,
    sample_if_multi=True,
    pad_value="<pad>",
    add_bos=False
):
  """Tokenizes a text.

  Let's assume max_len=3 and id("</s>")=1, id("a")=2, then we have

  1. `eos="none", pad_value=0`:
     - "a" -> [2, 0, 0]
     - "aa" -> [2, 2, 0]
     - "aaa" -> [2, 2, 2]

  2. `eos="yes", pad_value=0`:
     - "a" -> [2, 1, 0]
     - "aa" -> [2, 2, 1]
     - "aaa" -> [2, 2, 2]

  This is usually used with generative models that need to learn when to
  properly predict a "</s>" (when the sentence is finished) and when to
  abstain (when the sentence is truncated).

  3. `eos="sticky", pad_value=0`:
     - "a" -> [2, 1, 0]
     - "aa" -> [2, 2, 1]
     - "aaa" -> [2, 2, 1]

  4. `eos="sticky", pad_value=1`:
     - "a" -> [2, 1, 1]
     - "aa" -> [2, 2, 1]
     - "aaa" -> [2, 2, 1]

  This is traditionally used with contrastive models that use the last token
  for embeddings, similarly to "cls" tokens in BERT-style models.

  Args:
    max_len: maximum length of the tokenized text.
    eos: Whether to add an "</s>" (end of sentence) token and whether to keep
      it when the sequence is longer than `max_len - 1`. See the examples
      above for details. Valid values: "none", "yes", "sticky".
    model: a known tokenizer name (see `KNOWN_TOKENIZERS`) or a path to a
      pretrained sentencepiece model.
    lower: lowercase the text before tokenizing.
    sample_if_multi: If the input contains multiple texts, randomly sample a
      single one when True; otherwise tokenize all texts and keep the input's
      batch shape in the result.
    pad_value: which token to pad the sequence with. If a string (for example
      `"<pad>"`), tokenize it and use its first token. Note that there is no
      guarantee to have any padding at the end of the sequence if it is
      longer than `max_len`.
    add_bos: adds a beginning-of-sentence token.

  Returns:
    an op that outputs tokenized text.
  """
  if eos not in ("yes", "none", "sticky"):
    raise ValueError(f"Invalid value for eos: '{eos}'.")

  tokenizer = create_tokenizer(model, add_eos=eos != "none", add_bos=add_bos)

  if isinstance(pad_value, str):
    pad_value = tokenizer.string_to_id(pad_value)

  def _pp_tokenize(txt):
    if sample_if_multi and tf.convert_to_tensor(txt).ndim:
      logging.warning("sample_if_multi is deprecated and will be removed. "
                      "Call `choice` (and maybe `setdefault`) instead.")
      txt = ops_general.get_choice(key="t")(
          ops_general.get_setdefault("t", "")({"t": txt}))["t"]

    if lower:
      txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn(
          tf.strings.lower, txt)

    return tokenize(
        txt,
        tokenizer,
        max_len,
        pad_value=pad_value,
        force_eos=eos == "sticky",
        multi_text=not sample_if_multi)

  return _pp_tokenize
|
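# Rough usage sketches (op strings are illustrative, not prescriptive): in a
# config's pp string the op above might appear as
#   "tokenize(max_len=64, eos='sticky', pad_value=1, inkey='texts')"
# for contrastive training (case 4 in the docstring), or as
#   "tokenize(max_len=16, eos='yes', inkey='suffix')"
# for generative training (case 2); `inkey`/`outkey` come from InKeyOutKey.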

@Registry.register("preprocess_ops.coco_captions")
def get_coco_captions(outkey="captions"):
  """Extracts coco's captions from nested dict."""

  def _pp_coco_captions(data):
    data[outkey] = data["captions"]["text"]
    return data

  return _pp_coco_captions


@Registry.register("preprocess_ops.clip_i1k_label_names")
@utils.InKeyOutKey(indefault="label", outdefault="labels")
def get_pp_clip_i1k_label_names():
  """Convert i1k label numbers to strings, using CLIP's class names."""

  def _pp_imagenet_labels(label):
    return tf.gather(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, label)

  return _pp_imagenet_labels


@Registry.register("preprocess_ops.lower")
@utils.InKeyOutKey(indefault="text", outdefault="text")
def get_lower():
  """Lowercases text feature."""

  def _pp_lower(text):
    return tf.strings.lower(text)

  return _pp_lower


def _add_pieces(model_bytes, extra_pieces):
  """Adds extra pieces to the sentencepiece model specified by `model_bytes`."""
  model = SPProcessor()
  model.LoadFromSerializedProto(model_bytes)
  unk_idx = model.PieceToId("<unk>")
  assert model.IdToPiece(unk_idx) == "<unk>", model.IdToPiece(unk_idx)

  model_proto = SPModelProto.FromString(model_bytes)
  idx_to_updated_piece = {}
  for piece in extra_pieces:
    # Sentencepiece represents spaces as "▁".
    piece = piece.replace(" ", "▁")
    spiece = model_proto.SentencePiece(
        piece=piece,
        score=0.0,
        type=model_proto.SentencePiece().Type.USER_DEFINED,
    )
    existing_idx = model.PieceToId(piece)
    if (existing_idx != unk_idx) ^ (piece == "<unk>"):
      # The piece is already in the vocab (or is "<unk>" itself): overwrite it
      # in place instead of appending a duplicate.
      idx_to_updated_piece[existing_idx] = spiece
      logging.info("Updating token at idx %d: %s", existing_idx, spiece.piece)
    else:
      model_proto.pieces.append(spiece)

  # Rebuild the repeated field with the updated pieces.
  updated_pieces = [
      idx_to_updated_piece.get(i, piece)
      for i, piece in enumerate(model_proto.pieces)
  ]
  del model_proto.pieces[:]
  model_proto.pieces.extend(updated_pieces)

  return model_proto.SerializeToString()
|
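# Illustrative sketch of what `_add_pieces` does (piece strings are
# hypothetical): given serialized sentencepiece model bytes `mb`,
#   mb2 = _add_pieces(mb, ["<img>", "<loc0000>"])
# returns model bytes whose vocab additionally contains "<img>" and
# "<loc0000>" as user-defined pieces, while existing pieces keep their ids.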

def _iterable(x):
  """Returns True if `x` looks like a batch of token sequences."""
  if isinstance(x, tf.RaggedTensor):
    return True
  if getattr(x, "ndim", 0) > 1:
    return True
  if isinstance(x, (list, tuple)) and not isinstance(x[0], (int, float)):
    return True
  return False
@Registry.register("tokenizers.sp") |
|
class SentencepieceTokenizer(bv_tok.Tokenizer): |
|
"""Wraps a `tftext.SentencepieceTokenizer`. |
|
|
|
If you plan to use this tokenizer, please familiarize yourself with the test |
|
cases first. This is likely to save you a lot of troubles down the road, trust |
|
me! |
|
""" |
|
|
|
def __init__(self, model, tokensets=()): |
|
with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f: |
|
model_bytes = f.read() |
|
extras = bv_tok.get_extra_tokens(tokensets) |
|
model_bytes = _add_pieces(model_bytes, extras) |
|
self._tok_sp = SPProcessor() |
|
self._tok_sp.LoadFromSerializedProto(model_bytes) |
|
self.extras = {self._tok_sp.PieceToId(x): x for x in extras} |
|
|
|
def to_int(self, text, *, bos=False, eos=False): |
|
def _single(s): |
|
return ( |
|
([self.bos_token] if bos else []) + |
|
self._tok_sp.EncodeAsIds(s) + |
|
([self.eos_token] if eos else []) |
|
) |
|
if isinstance(text, str): |
|
return _single(text) |
|
return type(text)([_single(s) for s in text]) |
|
|
|
def to_str(self, tokens, *, stop_at_eos=True): |
|
def _single(toks): |
|
toks = [int(t) for t in toks] |
|
if stop_at_eos: |
|
try: |
|
toks = toks[:toks.index(self.eos_token)] |
|
except ValueError: |
|
pass |
|
return self._tok_sp.DecodeIds(toks) |
|
if _iterable(tokens): |
|
return [_single(toks) for toks in tokens] |
|
return _single(tokens) |
|
|
|
def _check_known(self, piece): |
|
if (id_ := self._tok_sp.PieceToId(piece)) == self._tok_sp.unk_id(): |
|
logging.error("Piece '%s' is not known (unk=%s)!", piece, id_) |
|
return id_ |
|
|
|
def to_piece(self, idx): |
|
return self._tok_sp.IdToPiece(int(idx)) |
|
|
|
@property |
|
def pad_token(self): |
|
return self._tok_sp.pad_id() |
|
|
|
@property |
|
def eos_token(self): |
|
return self._tok_sp.eos_id() |
|
|
|
@property |
|
def bos_token(self): |
|
return self._tok_sp.bos_id() |
|
|
|
@property |
|
def vocab_size(self): |
|
return self._tok_sp.GetPieceSize() |
|
|
|
|
|
|
|
def to_int_tf_op(self, text, *, bos=False, eos=False): |
|
text = tf.convert_to_tensor(text) |
|
if text.ndim == 0: |
|
def fn(txt): |
|
string = txt.numpy().decode() |
|
return tf.constant(self.to_int(string, bos=bos, eos=eos), tf.int32) |
|
return tf.py_function(fn, [text], tf.int32) |
|
else: |
|
def fn(txt): |
|
strings = [s.decode() for s in txt.numpy().tolist()] |
|
toks = self.to_int(strings, bos=bos, eos=eos) |
|
return tf.ragged.constant(toks) |
|
out_type = tf.RaggedTensorSpec([tf.shape(text)[0], None], tf.int32) |
|
return tf.py_function(fn, [text], Tout=out_type) |
|
|
|
def to_str_tf_op(self, tokens, *, stop_at_eos=True): |
|
def single(t): |
|
fn = functools.partial(self.to_str, stop_at_eos=stop_at_eos) |
|
return tf.numpy_function(fn, [t], tf.string, stateful=False) |
|
if _iterable(tokens): |
|
return tf.map_fn(single, tokens, tf.string) |
|
return single(tokens) |
|
|
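# A minimal usage sketch (model name as registered above; the decoded strings
# depend on the actual vocab):
#   tok = SentencepieceTokenizer("c4_en")
#   ids = tok.to_int("a cat", eos=True)   # Python list of ids, ending in eos.
#   txt = tok.to_str(ids)                 # -> "a cat"
#   batch = tok.to_int_tf_op(tf.constant(["a cat", "a dog"]))  # RaggedTensor.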