import logging
from bisect import bisect_left
from collections import OrderedDict

import cv2
import numpy as np
import torch

from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, _convert_to_rgb


logger = logging.getLogger(__name__)

# Hyper-parameters
_IMAGE_BONUS_VALUE = 2  # The bonus value for tokens preceding the image token
_MIN_LENGTH_DOCUMENTS_TO_PACK = (
    5  # Minimum length of documents to pack together (length is measured in number of tokens)
)

# Marker found inside raw texts wherever one document ends and the next one begins.
_END_OF_DOCUMENT_MARKER = "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED"


def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
    """Convert an incremental image-index mask into a one-hot (binary) attention mask.

    Example: ``[-1, 0, 1]`` => ``[[0, 0], [1, 0], [0, 1]]``.

    Args:
        incremental_mask (torch.LongTensor): for each token, the index of the image it attends to, or -1 when it
            attends to no image. The tensor passed in is not modified.
        num_classes (int, optional): width of the one-hot encoding (i.e. max number of images). Indices
            ``>= num_classes`` are treated as -1, so words seen after the maximum number of allowed images attend
            to nothing. If -1, the width is inferred by ``torch.nn.functional.one_hot``.

    Returns:
        torch.LongTensor of shape ``incremental_mask.shape + (C,)`` where ``C`` is ``num_classes`` (or the inferred
        width), with all-zero rows for tokens that attend to no image.
    """
    # Work on a copy: the masking below is done in place and must not leak into the caller's tensor.
    incremental_mask = incremental_mask.clone()
    if num_classes != -1:
        incremental_mask[incremental_mask >= num_classes] = -1

    negatives = incremental_mask == -1
    # one_hot rejects negative indices, so temporarily map -1 to 0 and zero those rows afterwards.
    incremental_mask[negatives] = 0
    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
    attn_mask[negatives, :] = 0
    return attn_mask


def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
    """Compute incremental image attention masks for (possibly packed) rows of token ids.

    Args:
        input_ids (torch.LongTensor): 2D tensor of token ids, one row per sequence.
        tokenizer: tokenizer providing the `IMAGE_TOKEN` id and `eos_token_id`.

    Returns:
        A tuple ``(image_attention_mask, next_image_attention_mask)`` of tensors shaped like `input_ids`:
        - `image_attention_mask`: for each token, the row-level index of the most recent image token at or before
          it, or -1 if there is none or if an EOS token was seen (strictly before this token) since that image.
        - `next_image_attention_mask`: for each token, the row-level index of the next image token at or after it,
          or -1 if there is none or if an EOS token lies between this token (inclusive) and that image.
    """
    image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id

    # Forward pass: attend to the latest image seen in the current document.
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx, token_id in enumerate(input_ids[batch_idx]):
            if token_id == image_token_id:
                count += 1
                image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                image_attention_mask[batch_idx][idx] = count

            if seen_eod:
                image_attention_mask[batch_idx][idx] = -1

            if token_id == eod_token_id:
                seen_eod = True

    # Backward pass: attend to the next upcoming image in the current document.
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
            token_id = input_ids[batch_idx][idx]
            if token_id == image_token_id:
                count += 1
                next_image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                next_image_attention_mask[batch_idx][idx] = count

            if token_id == eod_token_id:
                seen_eod = True

            if seen_eod:
                next_image_attention_mask[batch_idx][idx] = -1

        # Indices were counted from the end of the row; convert them to front-to-back image indices.
        non_negative_indices = next_image_attention_mask[batch_idx] != -1
        next_image_attention_mask[batch_idx][non_negative_indices] -= count
        next_image_attention_mask[batch_idx][non_negative_indices] *= -1

    return image_attention_mask, next_image_attention_mask


def laplacian_blur_detection(image, threshold=0.0):
    """Return True if an RGB image looks blurry according to the variance of its Laplacian.

    The focus measure is the variance of the Laplacian of the grayscale image; the image is flagged as blurry
    when it is below `threshold`. A `threshold` of 0.0 disables the check, and non 3-channel images (e.g.
    grayscale) are never flagged.
    """
    if threshold == 0.0:
        return False

    image = np.array(image)

    if len(image.shape) == 3 and image.shape[2] == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return cv2.Laplacian(gray, cv2.CV_64F).var() < threshold
    else:
        # Don't remove grayscale images
        return False


def fft_blur_detection(image, size=50, threshold=0.0):
    """Return True if an image looks blurry according to its high-frequency FFT content.

    The low frequencies (a ``2*size`` x ``2*size`` window around the centre of the shifted spectrum) are removed,
    the image is reconstructed and the mean log-magnitude of the reconstruction is compared to `threshold`.
    A `threshold` of 0.0 disables the check.

    Args:
        image: 2D grayscale array, or anything convertible with ``np.asarray`` to a (H, W, C) array (e.g. an RGB
            array or PIL image); multi-channel inputs are reduced to luminance first.
        size (int): half-size of the low-frequency window that is zeroed out.
        threshold (float): blur threshold on the mean magnitude.
    """
    if threshold == 0.0:
        return False

    image = np.asarray(image)
    if image.ndim == 3:
        # Reduce channels to a single plane so that RGB arrays / PIL images are supported.
        if image.shape[2] >= 3:
            image = image[..., :3] @ np.array([0.299, 0.587, 0.114])
        else:
            image = image.mean(axis=2)

    (h, w) = image.shape
    (cX, cY) = (int(w / 2.0), int(h / 2.0))
    fft = np.fft.fft2(image)
    fft_shift = np.fft.fftshift(fft)
    # Clamp the window start at 0: a negative slice start would wrap around for images smaller than 2*size.
    fft_shift[max(0, cY - size) : cY + size, max(0, cX - size) : cX + size] = 0
    fft_shift = np.fft.ifftshift(fft_shift)
    recon = np.fft.ifft2(fft_shift)

    magnitude = 20 * np.log(np.abs(recon))
    mean = np.mean(magnitude)
    return mean < threshold


def _has_any_image(images):
    """Return True if at least one entry of `images` is not None (avoids truthiness checks on arrays/tensors)."""
    return any(image is not None for image in images)


def _empty_result(is_raw_images, with_next_image_attention_mask=True):
    """Return the batch dict used when no sequence could be built: the usual keys, all holding empty tensors."""
    result = {
        "input_ids": torch.tensor([], dtype=torch.long),
        "attention_mask": torch.tensor([], dtype=torch.bool),
        "image_attention_mask": torch.tensor([], dtype=torch.bool),
    }
    if with_next_image_attention_mask:
        result["next_image_attention_mask"] = torch.tensor([], dtype=torch.bool)
    result["num_images"] = torch.tensor([], dtype=torch.long)
    result["num_text_tokens"] = torch.tensor([], dtype=torch.long)
    if is_raw_images:
        result["pixel_values"] = torch.tensor([], dtype=torch.float32)
    else:
        result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
    return result


def _split_on_end_of_document(raw_images, raw_texts):
    """Split one interleaved (images, texts) document into sub-documents at `_END_OF_DOCUMENT_MARKER`.

    Text before a marker closes the current sub-document; text after it opens the next one. Sub-documents
    without any image are dropped. The input lists are not modified.

    Returns:
        Tuple ``(splitted_raw_images, splitted_raw_texts)`` of parallel lists of sub-documents.
    """
    inds_of_texts_to_split = [
        i for i, text in enumerate(raw_texts) if isinstance(text, str) and _END_OF_DOCUMENT_MARKER in text
    ]
    if not inds_of_texts_to_split:
        return [raw_images], [raw_texts]

    # Copy: entries are rewritten below and the caller's sample must stay untouched.
    raw_texts = list(raw_texts)
    splitted_raw_images, splitted_raw_texts = [], []
    previous_i = 0
    for i in inds_of_texts_to_split:
        splitting = raw_texts[i].split(_END_OF_DOCUMENT_MARKER)
        part1, part2 = splitting[0], splitting[-1]

        sub_doc_images = list(raw_images[previous_i:i]) + [None]
        sub_doc_texts = list(raw_texts[previous_i:i]) + [part1.strip()]
        # A sub-document can have no image (e.g. all of raw_images[previous_i:i] are None): drop it, but the
        # cursor below must still advance, otherwise its text would leak into the next sub-document.
        if _has_any_image(sub_doc_images):
            splitted_raw_images.append(sub_doc_images)
            splitted_raw_texts.append(sub_doc_texts)

        if part2.strip() == "":
            previous_i = i + 1
        else:
            raw_texts[i] = part2.strip()
            previous_i = i

    if previous_i < len(raw_images) and _has_any_image(raw_images[previous_i:]):
        splitted_raw_images.append(raw_images[previous_i:])
        splitted_raw_texts.append(raw_texts[previous_i:])

    return splitted_raw_images, splitted_raw_texts


def split_pack_and_pad(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    max_num_samples_per_document=10,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
    max_num_images_per_document=None,
):
    """
    Return a batch of samples in the format expected by the model which includes `input_ids`, `pixel_values`
    (or `image_embeddings`), `attention_mask`, `image_attention_mask`, and `next_image_attention_mask`. The
    `input_ids` are sampled from the document to ensure it has `max_seq_len` tokens; otherwise, the shorter
    documents are packed together.

    For each document, we sample a maximum of `max_num_samples_per_document` or
    `max_num_samples_for_curr_document` (where the latter is proportional to the length of the document and
    inversely proportional to the length of subsequences) `input_ids` with sequence length `max_seq_len` from the
    document. This means that each sample can have a different start index. Based on the start index of the
    sampled sub-sequence, we also sample a maximum of `max_num_images` images from the document. If there are
    fewer than `max_num_images` images, we pad the images with zeros. The start indexes are skewed towards
    subsequences that contain images.

    Args:
        sample (Dict): A sample object containing the documents with images and text. Texts are read from
            `sample["texts"]` and images from `sample["image_embeddings"]` or, if absent, `sample["images"]`.
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Maximum sequence length of the returned text tokens.
        image_transform (Callable): Transform to be applied on the raw images.
        max_num_images (int): Maximum number of images to be sampled per sample. If less, they are padded with
            zeros.
        max_num_samples_per_document (int, optional): Maximum number of samples per document to be sampled.
            Defaults to 10.
        prefix_seed: Prefix seed sequence for "reproducible randomness" in calls to `np.random.choice`.
        is_blurred_fn (Callable, optional): ``fn(image, threshold=...) -> bool`` used to drop documents containing
            blurry raw images. Defaults to `fft_blur_detection`.
        blur_threshold (float, optional): Threshold passed to `is_blurred_fn`; 0.0 disables blur filtering.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to each document.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to each document.
        max_num_images_per_document (int, optional): Documents with more images than this are skipped.

    Returns:
        Dict[str, torch.Tensor] with keys `input_ids` (num_samples, max_seq_len), `attention_mask`
        (num_samples, max_seq_len), `image_attention_mask` and `next_image_attention_mask`
        (num_samples, max_seq_len, max_num_images), `num_images` and `num_text_tokens` (num_samples,), plus
        `pixel_values` (raw images) or `image_embeddings` of shape (num_samples, max_num_images, ...).
        All tensors are empty when no sample could be built.

    Raises:
        ValueError: if the sample has neither `image_embeddings` nor `images`, or if splitting on the
            end-of-document marker produced inconsistent image/text lists.
    """
    text_batch = sample["texts"]

    image_batch = sample.get("image_embeddings", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("images", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embeddings or images must be present in the sample")

    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    all_images = []
    all_texts = []
    for raw_images, raw_texts in zip(image_batch, text_batch):
        # Filter out documents that don't have at least one image and one text
        if not _has_any_image(raw_images) or not any(raw_texts):
            continue

        if max_num_images_per_document:
            num_images = sum(image is not None for image in raw_images)
            if num_images > max_num_images_per_document:
                continue

        any_blurred = False

        if is_raw_images and blur_threshold > 0.0:
            for image in raw_images:
                if image is not None:
                    image = _convert_to_rgb(image)
                    any_blurred = any_blurred or is_blurred_fn(image, threshold=blur_threshold)
                    if any_blurred:
                        break

        if any_blurred:
            continue

        splitted_raw_images, splitted_raw_texts = _split_on_end_of_document(raw_images, raw_texts)

        # Sanity check
        if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
            raise ValueError(
                "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                " Something core went wrong during the splitting and needs to be fixed."
            )

        for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
            images, web_text = [], ""
            # Reset for every sub-document: otherwise a document ending with an image would prefix the next
            # document's first text with a spurious FAKE_TOKEN_AROUND_IMAGE_V2.
            last_was_image = False
            for image, text in zip(s_r_ims, s_r_txts):
                if text is None and image is None:
                    continue

                if image is not None:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                    if is_raw_images:
                        images.append(image_transform(image))
                    else:
                        images.append(torch.tensor(image))
                    last_was_image = True
                elif text is not None:
                    if last_was_image:
                        web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                        last_was_image = False
                    else:
                        web_text += f" {text}" if web_text != "" else text

            if last_was_image:
                web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"

            web_text = web_text.strip(" ")

            # This is mostly a sanity check. Cases like that should not happen at that point.
            if web_text == "" or len(images) == 0:
                continue

            images = torch.stack(images)
            all_images.append(images)

            web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
            if add_end_of_doc_token:
                web_text_ids += [tokenizer.eos_token_id]

            if add_begin_of_doc_token:
                web_text_ids = [tokenizer.bos_token_id] + web_text_ids
            all_texts.append(web_text_ids)

    output_input_ids = []
    output_images = []
    output_attention_masks = []
    output_num_images = []
    output_num_text_tokens = []

    input_ids_to_pack = []
    images_to_pack = []
    for images, text in zip(all_images, all_texts):
        # We save all the documents which are shorter than the max_seq_len to pack them together.
        if len(text) <= max_seq_len:
            if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK:  # Filter out extremely short sequences
                continue
            input_ids_to_pack.extend(text)
            images_to_pack.extend(images)
        else:
            # Computing the bonus scores for tokens near images to skew the sampling towards them
            # The main idea is to give a bonus to tokens that are closely before an image token, so that these
            # tokens have more chance to be sampled.
            # Bonuses are computed for each image, which means a given token can receive bonuses from multiple
            # images if this token is closely preceding multiple images.
            # We sum all the bonuses and L1 normalize along the seq_len axis to get a probability distribution.
            # Each token starts with a regular bonus of 1, which corresponds to the uniform distribution over the
            # sequence when there are no bonuses added.

            # Now the remaining question is which preceding tokens we distribute bonuses to.
            # We first observe that for the sampled sub-sequence to be considered valid (i.e. sub-sequence
            # contains an image), the start index can only be among [image_idx - max_seq_len + 1, image_idx].
            # For the sake of the explanation, let's split the [image_idx - max_seq_len + 1, image_idx] interval
            # in 3 parts: left, middle and right (in increasing order).
            # If we give bonuses to the tokens just before the image (right part), then we are favoring p_next=0
            # because only the tokens after the image have an image to attend to.
            # In practice, images will tend to be at the beginning of the sampled sub-sequence.
            # If we give bonuses very far before the image (left part), then we are favoring p_next=1 because only
            # the tokens before the image have an image to attend to.
            # In practice, images will tend to be at the end of the sampled sub-sequence.
            # To avoid favoring p_next=0 or p_next=1, we can give bonuses to the tokens in the middle part.
            # In practice, images will tend to be in the middle of the sampled sequence.
            # Ultimately, we don't want to skew the distribution fed to the model in that way (i.e. whether images
            # are in the beginning, middle or end of the sampled sub-sequence), and want all these cases
            # represented equally in the data. So the easiest is to distribute a bonus to all of the valid start
            # indices, i.e. the max_seq_len tokens ending with the image token.
            all_scores = np.ones(len(text), dtype=np.int64)
            for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
                all_scores[max(0, img_token_idx - max_seq_len + 1) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
            # We can optionally clip the bonuses to avoid having too high values (i.e. outlier documents):
            # all_scores = np.clip(all_scores, a_min=1, a_max=3 * _IMAGE_BONUS_VALUE * max_num_images + 1)
            all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]

            # The number of samples is proportional to the length of the text and inversely proportional to the
            # maximum sequence length
            max_num_samples_for_curr_document = len(text) // max_seq_len
            # Set "reproducible randomness" by creating an np.default_rng seeded by
            # (main seed, epoch, rank_idx, worker_idx, mapped_batch_index, text len)
            choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
                range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK),  # shorter sub-sequences are reserved for packing
                # Sampling more than necessary and then breaking out of the for loop once we have enough samples
                min(len(text) - max_seq_len, 2 * max_num_samples_per_document),
                p=all_scores / np.linalg.norm(all_scores, ord=1),
                replace=False,
            )

            nb_effective_sequences_out_of_sampling = 0
            for start_index in choices:
                image_start_index = text[:start_index].count(image_token_id)
                text_sub_sequence = text[start_index : start_index + max_seq_len]

                image_count = text_sub_sequence.count(image_token_id)
                if image_count == 0:
                    # Skip if there are no images in the sequence
                    continue

                if len(text_sub_sequence) < max_seq_len:
                    # If the sub-sequence is shorter than max_seq_len, we reserve it for packing.
                    # It necessarily means that the sub-sequence was sampled towards the end of the document,
                    # which implies that we only need the `image_start_index` and not the `image_end_index`
                    if image_count != len(images[image_start_index:]):
                        # A safeguard for this
                        logger.warning(
                            "Skipping this sample because of mismatch in actual number of images and "
                            "the '%s' tokens in the text",
                            IMAGE_TOKEN,
                        )
                        continue
                    input_ids_to_pack.extend(text_sub_sequence)
                    images_to_pack.extend(images[image_start_index:])
                    continue

                num_kept_images = min(max_num_images, image_count)
                current_images = images[image_start_index : image_start_index + num_kept_images]
                if len(current_images) != num_kept_images:
                    # A safeguard for something off about this document, e.g. an image token that was already
                    # present in the raw text, or some issue in parsing the images.
                    logger.warning(
                        "Skipping this sample because of mismatch in actual number of images and "
                        "the '%s' tokens in the text",
                        IMAGE_TOKEN,
                    )
                    break

                padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
                padded_image_tensor[:num_kept_images] = current_images
                output_images.append(padded_image_tensor)
                output_num_images.append(num_kept_images)

                output_input_ids.append(torch.tensor(text_sub_sequence))
                output_num_text_tokens.append(len(text_sub_sequence))

                attention_mask = torch.ones((max_seq_len,), dtype=torch.long)
                output_attention_masks.append(attention_mask)

                nb_effective_sequences_out_of_sampling += 1
                if nb_effective_sequences_out_of_sampling >= min(
                    max_num_samples_for_curr_document, max_num_samples_per_document
                ):
                    # We got all the samples we need for this document, so breaking out
                    break

    # Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack`
    if input_ids_to_pack:
        image_counter = 0
        for i in range(0, len(input_ids_to_pack), max_seq_len):
            current_input_ids = input_ids_to_pack[i : i + max_seq_len]
            unpadded_seq_len = len(current_input_ids)
            num_images = current_input_ids.count(image_token_id)
            if num_images == 0:
                continue
            current_images = images_to_pack[image_counter : image_counter + num_images]
            image_counter += num_images
            if unpadded_seq_len < max_seq_len:
                padded_input_ids = [tokenizer.pad_token_id] * max_seq_len
                padded_input_ids[:unpadded_seq_len] = current_input_ids
                current_input_ids = padded_input_ids
            elif unpadded_seq_len > max_seq_len:
                # This case has no purpose other than safeguard
                continue
            try:
                current_images = torch.stack(current_images)[:max_num_images]
            except Exception as e:
                # Best effort: e.g. no images left to match the image tokens, or images of different shapes.
                logger.warning("Skipping a packed sequence because its images could not be stacked: %s", e)
                continue
            padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
            padded_image_tensor[: current_images.size(0)] = current_images
            attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
            attention_mask[:unpadded_seq_len] = 1

            output_images.append(padded_image_tensor)
            output_input_ids.append(torch.tensor(current_input_ids))
            output_num_text_tokens.append(unpadded_seq_len)
            # Report the images actually kept, which can be fewer than the image tokens if images ran out.
            output_num_images.append(current_images.size(0))
            output_attention_masks.append(attention_mask)

    if len(output_images) == 0 or len(output_input_ids) == 0:
        return _empty_result(is_raw_images)

    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
        output_input_ids, tokenizer
    )
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    result = {
        "input_ids": output_input_ids,
        "attention_mask": output_attention_masks,
        "image_attention_mask": image_attention_mask,
        "next_image_attention_mask": next_image_attention_mask,
        "num_images": torch.tensor(output_num_images),
        "num_text_tokens": torch.tensor(output_num_text_tokens),
    }
    if is_raw_images:
        result["pixel_values"] = output_images
    else:
        result["image_embeddings"] = output_images

    return result


def split_and_pad_pmd(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    prob_image_at_end=0.5,  # If 1, the image token is always added at the end of the text
    # If set to -1, all padding will be tolerated. If set to 0, no padding will be tolerated.
    padding_tolerance=-1,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
):
    """Build a padded batch from image-text pairs, packing several pairs into each `max_seq_len` sequence.

    Each pair becomes ``[text] + image tokens`` or ``image tokens + [text]``, the position being drawn once for
    the whole batch from `prob_image_at_end`. Pairs are then packed greedily: a sequence is started with the
    longest remaining pair and repeatedly extended with the longest remaining pair that still fits.

    Args:
        sample (Dict): texts are read from `sample["text"]`, images from `sample["image_embedding"]` or, if
            absent, `sample["image"]`.
        tokenizer: Text tokenizer to be used.
        max_seq_len (int): Length of every returned sequence.
        image_transform (Callable): Transform applied to raw images.
        max_num_images (int): Number of image slots per sequence (extra images are dropped, missing ones are
            zero-padded).
        prefix_seed: Prefix seed sequence for "reproducible randomness".
        is_blurred_fn (Callable, optional): ``fn(image, threshold=...) -> bool`` used to drop blurry raw images.
            Defaults to `fft_blur_detection`.
        blur_threshold (float, optional): Threshold passed to `is_blurred_fn`.
        prob_image_at_end (float, optional): Probability that images are placed after the text.
        padding_tolerance (int, optional): Maximum padding tolerated per sequence before it is topped up with
            randomly chosen pairs; -1 tolerates any amount of padding.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to each pair.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to each pair.

    Returns:
        Dict[str, torch.Tensor] with `input_ids`, `attention_mask`, `image_attention_mask`, `num_images`,
        `num_text_tokens` and `pixel_values` or `image_embeddings`. When images are at the end, the
        `image_attention_mask` is built from the *next*-image mask. All tensors are empty when nothing is left.

    Raises:
        ValueError: if the sample has neither `image_embedding` nor `image`.
    """
    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    text_batch = sample["text"]
    image_batch = sample.get("image_embedding", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("image", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embedding or image must be present in the sample")

    filtered_image_batch = []
    filtered_input_ids = []

    # Decide, for the whole PMD batch, whether the images will be at the start or at the end.
    rng = np.random.default_rng(seed=list(prefix_seed))
    is_image_at_end = False

    # rng.random is between 0 and 1, so if prob_image_at_end is 1, random value will
    # always be less than `prob_image_at_end` and `is_image_at_end` will always be True.
    # This means that images will always be at the end of the text.
    if rng.random() < prob_image_at_end:
        is_image_at_end = True

    for image, text in zip(image_batch, text_batch):
        if text is None or image is None:
            continue

        if is_raw_images and is_blurred_fn(image, threshold=blur_threshold):
            continue

        sample_text = f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}{FAKE_TOKEN_AROUND_IMAGE_V2}"

        # Remove trailing and leading whitespaces, including newlines and tabs
        text = text.strip()

        if is_image_at_end:
            sample_text = f"{text}{sample_text}"
        else:
            sample_text = f"{sample_text}{text}"

        sample_input_ids = tokenizer.encode(sample_text, add_special_tokens=False)
        if add_end_of_doc_token:
            sample_input_ids += [tokenizer.eos_token_id]

        if add_begin_of_doc_token:
            sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids

        filtered_image_batch.append(image)
        filtered_input_ids.append(sample_input_ids)

    if not filtered_input_ids:
        # Every pair was filtered out; the `zip(*sorted(...))` unpacking below would fail on an empty list.
        return _empty_result(is_raw_images, with_next_image_attention_mask=False)

    # Sort by length of text and save same length elements in a mapping so we
    # can retrieve candidates later.
    filtered_image_batch, filtered_input_ids = zip(
        *sorted(zip(filtered_image_batch, filtered_input_ids), key=lambda x: len(x[1]))
    )
    # Keys are inserted in increasing length order, so `list(mapping_by_len)` stays sorted for bisect.
    mapping_by_len = OrderedDict()
    for i, sample_input_ids in enumerate(filtered_input_ids):
        if len(sample_input_ids) not in mapping_by_len:
            mapping_by_len[len(sample_input_ids)] = []
        mapping_by_len[len(sample_input_ids)].append((filtered_image_batch[i], sample_input_ids))

    all_images = []
    all_texts = []
    all_attention_masks = []
    all_num_images = []
    all_num_text_tokens = []

    current_text = []
    current_images = []
    while True:
        current_lens = list(mapping_by_len.keys())
        if len(current_text) > 0:
            # Now we try to do a binary search to find the biggest sequence that
            # we can fit into the current sequence.
            # This will eventually use up bigger sequences faster which is good
            # and leave smaller sequences to pack with each other later.
            diff = max_seq_len - len(current_text)
            if len(current_lens) == 0:
                possible_index = -1
            else:
                possible_index = bisect_left(current_lens, diff)
                if possible_index == len(current_lens) or current_lens[possible_index] != diff:
                    possible_index -= 1

            if possible_index >= 0:
                best_possible_length = current_lens[possible_index]
                image, sample_input_ids = mapping_by_len[best_possible_length].pop(0)

                # If we have used up all the samples of a certain length, remove
                # that length from the mapping.
                if len(mapping_by_len[best_possible_length]) == 0:
                    del mapping_by_len[best_possible_length]
                current_text.extend(sample_input_ids)
                if is_raw_images:
                    current_images.append(image_transform(image))
                else:
                    current_images.append(torch.tensor(image))
            elif diff > padding_tolerance and padding_tolerance != -1:
                # If we are here, it means that we still have padding left
                # and we have exhausted our current unique options that will allow us to
                # fill this sequence completely.
                # So, we will try to fill the sequence with whatever we get from the unchanged
                # copy of all sequences.
                while diff > padding_tolerance:
                    # Find a random sequence to fit
                    # Why we need to add more stuff to prefix seed?
                    # prefix_seed will be same in the same batch which means that it might sample
                    # same thing again and again if there are multiple cases of padding in the
                    # same batch which means we need to make this part as random as possible.
                    # Build a flat list so that list or tuple `prefix_seed` both work; SeedSequence flattens
                    # nested sequences, so the entropy matches the former (prefix_seed + (..., all_num_images)).
                    rng = np.random.default_rng(
                        list(prefix_seed) + [diff, len(current_text), len(all_texts)] + list(all_num_images)
                    )
                    choice = rng.choice(range(len(filtered_input_ids)))
                    image, sample_input_ids = filtered_image_batch[choice], filtered_input_ids[choice]
                    current_text.extend(sample_input_ids)
                    if is_raw_images:
                        current_images.append(image_transform(image))
                    else:
                        current_images.append(torch.tensor(image))
                    diff = max_seq_len - len(current_text)
                # In the next top-level while loop iteration, this should go into the else
                # clause which should also handle the sequences longer than max_seq_len
            else:
                # Nothing fits any more (or padding is tolerated): finalize the current sequence.
                current_images = torch.stack(current_images)
                num_kept_images = min(max_num_images, current_images.size(0))
                padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
                padded_image_tensor[:num_kept_images] = current_images[:num_kept_images]
                all_num_images.append(num_kept_images)
                all_images.append(padded_image_tensor)

                padded_input_ids = torch.full((max_seq_len,), tokenizer.pad_token_id)
                current_max_len = min(max_seq_len, len(current_text))
                padded_input_ids[:current_max_len] = torch.tensor(current_text)[:current_max_len]
                all_num_text_tokens.append(current_max_len)
                all_texts.append(padded_input_ids)

                attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
                attention_mask[: len(current_text)] = 1
                all_attention_masks.append(attention_mask)

                # Make sure to reset the current text and images.
                current_images = []
                current_text = []
                if len(current_lens) == 0:
                    break
        else:
            # A case where we might not have any samples left over after the initial filtering step.
            if len(current_lens) == 0:
                break
            # Start a new sequence with the longest remaining pair (truncated to max_seq_len).
            image, sample_input_ids = mapping_by_len[current_lens[-1]].pop(0)
            if len(mapping_by_len[current_lens[-1]]) == 0:
                del mapping_by_len[current_lens[-1]]
            current_text = sample_input_ids[:max_seq_len]
            if is_raw_images:
                current_images = [image_transform(image)]
            else:
                current_images = [torch.tensor(image)]

    if len(all_images) == 0 or len(all_texts) == 0:
        return _empty_result(is_raw_images, with_next_image_attention_mask=False)

    all_texts = torch.stack(all_texts)
    all_images = torch.stack(all_images)
    all_attention_masks = torch.stack(all_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(all_texts, tokenizer)
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    output = {
        "input_ids": all_texts,
        "attention_mask": all_attention_masks,
        "image_attention_mask": image_attention_mask,
        "num_images": torch.tensor(all_num_images),
        "num_text_tokens": torch.tensor(all_num_text_tokens),
    }
    if is_raw_images:
        output["pixel_values"] = all_images
    else:
        output["image_embeddings"] = all_images

    if is_image_at_end:
        # Set the correct attention mask based on whether the image is at the start
        # or not. When it is at the end, we need next image attention mask.
        output["image_attention_mask"] = next_image_attention_mask

    return output


# Copied from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py
def random_spans_helper(
    inputs_length,
    noise_density,
    mean_noise_span_length,
    extra_tokens_per_span_inputs,
    extra_tokens_per_span_targets,
    verbose=False,
):
    """Training parameters to avoid padding with random_spans_noise_mask.

    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparameters in a way that avoids padding. This function helps us
    compute these hyperparameters.

    We assume that each noise span in the input is replaced by
    extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
    targets is replaced by extra_tokens_per_span_targets sentinel tokens.

    This function tells us the required number of tokens in the raw example (for
    split_tokens()) as well as the length of the encoded targets. Note that this
    function assumes the inputs and targets will have EOS appended and includes that
    in the reported length.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
        extra_tokens_per_span_inputs: an integer (only 1 is supported)
        extra_tokens_per_span_targets: an integer (only 1 is supported)
        verbose: a bool indicating whether to log sequence lengths

    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence

    Raises:
        NotImplementedError: if either `extra_tokens_per_span_*` argument is not 1.
    """
    if extra_tokens_per_span_inputs != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_inputs != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )
    if extra_tokens_per_span_targets != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_targets != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
            num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
        )

    tokens_length = inputs_length - 1
    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    if verbose:
        logger.info(
            "tokens_length=%s inputs_length=%s targets_length=%s noise_density=%s mean_noise_span_length=%s ",
            tokens_length,
            inputs_length,
            targets_length,
            noise_density,
            mean_noise_span_length,
        )
    return tokens_length, targets_length