VictorSanh's picture
Update visualization
217780a
raw
history blame
34.9 kB
import logging
from bisect import bisect_left
from collections import OrderedDict
import cv2
import numpy as np
import torch
from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, _convert_to_rgb
logger = logging.getLogger(__name__)

# Hyper-parameters
# Added to the sampling score of tokens preceding an image token in long documents
# (see `split_pack_and_pad`), so that sampled sub-sequences more often contain images.
_IMAGE_BONUS_VALUE = 2  # The bonus value for tokens preceding the image token
# Documents shorter than this are dropped instead of packed; the last this-many tokens of long
# documents are also never used as sampling start indices (see `split_pack_and_pad`).
_MIN_LENGTH_DOCUMENTS_TO_PACK = (
    5  # Minimum length of documents to pack together (length is measured in number of tokens)
)
def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
    """Convert an incremental image-index mask into a one-hot (binary) attention mask.

    Each entry of ``incremental_mask`` holds the index of the image a token attends to, or ``-1``
    when it attends to no image. For example, with ``num_classes=2``::

        [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]

    Args:
        incremental_mask (torch.LongTensor): Integer tensor of image indices (``-1`` = no image).
            The tensor is NOT modified; the conversion works on a copy.
        num_classes (int, optional): Number of images a token can attend to. Indices
            ``>= num_classes`` are treated like ``-1``: words after the maximum number of allowed
            images has been seen don't attend to anything. ``-1`` (default) lets
            ``torch.nn.functional.one_hot`` infer the number of classes from the data.

    Returns:
        torch.LongTensor: Tensor of shape ``(*incremental_mask.shape, C)`` where ``C`` is
        ``num_classes`` (or the inferred number of classes), with all-zero rows for tokens that
        attend to no image.
    """
    # Work on a copy: the masking below is done in place and must not leak into the caller's tensor.
    incremental_mask = incremental_mask.clone()
    if num_classes != -1:
        incremental_mask[incremental_mask >= num_classes] = -1

    negatives = incremental_mask == -1
    # one_hot rejects negative indices, so temporarily map -1 to 0 and clear those rows afterwards.
    incremental_mask[negatives] = 0
    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
    attn_mask[negatives, :] = 0
    return attn_mask
def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
    """Compute incremental image-attention masks for (possibly packed) rows of token ids.

    Both returned tensors are created with ``torch.full_like(input_ids, ...)`` (same shape and dtype
    as ``input_ids``). For every token they hold the index of an image token in the same row
    (0 = first image token of the row), or ``-1`` when the token should attend to no image:

    - ``image_attention_mask``: index of the closest image token at or before the token. It is
      ``-1`` before the first image of the row, and for every token after the first EOS token
      that follows the most recent image (that EOS token itself keeps the image index).
    - ``next_image_attention_mask``: index of the closest image token at or after the token. It is
      ``-1`` after the last image of the row, and for tokens that have an EOS token at or after
      them and before the next image (so the EOS token itself is ``-1``).

    In this file the results are passed to ``incremental_to_binary_attention_mask``.

    Args:
        input_ids (torch.Tensor): 2-D tensor of token ids; rows (dim 0) are processed independently.
        tokenizer: Tokenizer providing the id of ``IMAGE_TOKEN`` and ``eos_token_id``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(image_attention_mask, next_image_attention_mask)``.
    """
    # NOTE(review): element-wise tensor indexing in Python loops is slow for long rows.
    image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id

    # Forward pass: each token gets the running index of the last image seen so far.
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx, token_id in enumerate(input_ids[batch_idx]):
            if token_id == image_token_id:
                count += 1
                image_attention_mask[batch_idx][idx] = count
                # A new image re-enables attention after a previous end of document.
                seen_eod = False
            else:
                image_attention_mask[batch_idx][idx] = count

            # Checked before updating `seen_eod`, so the EOS token itself still attends to the image.
            if seen_eod:
                image_attention_mask[batch_idx][idx] = -1

            if token_id == eod_token_id:
                seen_eod = True

    # Backward pass: walk each row from the end so every token gets the index of the next image,
    # first counted from the END of the row (0 = last image), then converted below.
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
            token_id = input_ids[batch_idx][idx]
            if token_id == image_token_id:
                count += 1
                next_image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                next_image_attention_mask[batch_idx][idx] = count

            # Checked after updating `seen_eod`, so the EOS token itself is masked out (-1).
            if token_id == eod_token_id:
                seen_eod = True

            if seen_eod:
                next_image_attention_mask[batch_idx][idx] = -1

        # `count` now equals (number of image tokens in the row) - 1; map reverse index v to
        # forward index count - v, leaving -1 entries untouched.
        non_negative_indices = next_image_attention_mask[batch_idx] != -1
        next_image_attention_mask[batch_idx][non_negative_indices] -= count
        next_image_attention_mask[batch_idx][non_negative_indices] *= -1

    return image_attention_mask, next_image_attention_mask
def laplacian_blur_detection(image, threshold=0.0):
    """Flag an RGB image as blurry using the variance of its Laplacian.

    The focus measure is the variance of the Laplacian of the grayscale image: sharp images have
    strong edges and therefore a high variance.

    Args:
        image: Image convertible with ``np.array`` (e.g. a PIL image or an ``H x W x C`` array).
        threshold (float): Images whose focus measure is below this value are considered blurry.
            ``0.0`` disables the check.

    Returns:
        bool: True if the image is considered blurry. Anything that is not a 3-channel image
        (e.g. grayscale) is never flagged.
    """
    if threshold == 0.0:
        return False

    pixels = np.array(image)
    is_rgb = pixels.ndim == 3 and pixels.shape[2] == 3
    if not is_rgb:
        # Don't remove grayscale images
        return False

    gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    focus_measure = cv2.Laplacian(gray, cv2.CV_64F).var()
    return focus_measure < threshold
def fft_blur_detection(image, size=50, threshold=0.0):
    """Flag an image as blurry based on its high-frequency content.

    The centered low-frequency square (half-width ``size``) of the 2D FFT is zeroed, the spectrum
    is transformed back to the spatial domain, and the mean log-magnitude
    (``20 * log(|recon|)``) of the remaining high-frequency signal is compared to ``threshold``.

    Args:
        image: 2-D grayscale array, or an RGB image (an ``H x W x 3`` array or anything
            ``np.asarray`` accepts, e.g. a PIL RGB image). RGB inputs are converted to grayscale.
        size (int): Half-width of the centered low-frequency square that is removed.
        threshold (float): Images whose mean high-frequency magnitude is below this value are
            considered blurry. ``0.0`` disables the check.

    Returns:
        bool: True if the image is considered blurry. Inputs that are neither 2-D nor 3-channel
        (e.g. RGBA) are never flagged.
    """
    if threshold == 0.0:
        return False

    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 3:
        # Same luma weights as OpenCV's RGB2GRAY (without the uint8 rounding).
        image = image.astype(np.float64) @ np.array([0.299, 0.587, 0.114])
    elif image.ndim != 2:
        # Don't remove images we cannot analyse.
        return False

    (h, w) = image.shape
    (cX, cY) = (int(w / 2.0), int(h / 2.0))
    fft = np.fft.fft2(image)
    fftShift = np.fft.fftshift(fft)
    # Clamp the window start at 0: a negative slice start would wrap around and zero the wrong
    # region for images smaller than 2 * size.
    fftShift[max(0, cY - size) : cY + size, max(0, cX - size) : cX + size] = 0
    fftShift = np.fft.ifftshift(fftShift)
    recon = np.fft.ifft2(fftShift)
    # A fully suppressed spectrum yields |recon| == 0 -> log = -inf, which (correctly) reads as blurry.
    with np.errstate(divide="ignore"):
        magnitude = 20 * np.log(np.abs(recon))
    mean = np.mean(magnitude)
    return bool(mean < threshold)
def split_pack_and_pad(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    max_num_samples_per_document=10,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
    max_num_images_per_document=None,
):
    """
    Return a batch of samples in the format expected by the model which
    includes `input_ids`, `pixel_values` (or `image_embeddings`), `attention_mask`, `image_attention_mask`,
    `next_image_attention_mask`, `num_images` and `num_text_tokens`. The `input_ids` are sampled from the
    document to ensure it has `max_seq_len` tokens; otherwise, the shorter documents are packed together.
    For each document, we sample a maximum of `max_num_samples_per_document` or `max_num_samples_for_curr_document`
    (where the latter is proportional to the length of the document and inversely proportional to the length
    of subsequences) `input_ids` with sequence length `max_seq_len` from the document. This means that
    each sample can have a different start index. Based on the start index of the sample that
    has been sampled, we also sample a maximum of `max_num_images` images from the document.
    If there are fewer than `max_num_images` images in the document, we pad the images with zeros.
    The start indexes are skewed towards subsequences that contain images.

    Texts containing the `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED` marker are first split into several
    sub-documents at that marker; each sub-document is tokenized on its own.

    Args:
        sample (Dict): A sample object containing the documents with images and text. Reads `texts` and
            either `image_embeddings` or `images` (per document: a list of images/None and a list of texts/None).
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Maximum sequence length of the returned text tokens.
        image_transform (Callable): Transform to be applied on the raw images.
        max_num_images (int): Maximum number of images to be sampled per sample. If less, they are padded with zeros.
        max_num_samples_per_document (int, optional): Maximum number of samples per document to be sampled.
            Defaults to 10.
        prefix_seed: Prefix seed sequence for "reproducible randomness" in the
            `np.random.default_rng(...).choice` call.
        is_blurred_fn (Callable, optional): `fn(image, threshold=...) -> bool` used to drop documents that
            contain a blurry raw image. Defaults to `fft_blur_detection`.
        blur_threshold (float, optional): Threshold passed to `is_blurred_fn`; `0.0` disables blur filtering.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to every (sub-)document.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to every (sub-)document.
        max_num_images_per_document (int, optional): If set, documents with more non-None images are dropped.

    Returns:
        Dict[str, torch.Tensor]: `input_ids` and `attention_mask` of shape (N, max_seq_len),
        `image_attention_mask` and `next_image_attention_mask` of shape (N, max_seq_len, max_num_images),
        `num_images` and `num_text_tokens` of shape (N,), and `pixel_values` (raw images) or
        `image_embeddings` of shape (N, max_num_images, ...). All tensors are empty when no sample
        could be built.

    Raises:
        ValueError: If the sample has neither `image_embeddings` nor `images`, or if splitting on
            `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED` produced inconsistent image/text lists.
    """
    text_batch = sample["texts"]
    image_batch = sample.get("image_embeddings", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("images", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embeddings or images must be present in the sample")

    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    all_images = []
    all_texts = []
    for raw_images, raw_texts in zip(image_batch, text_batch):
        # Filter out documents that don't have at least one image and one text
        if not any(raw_images) or not any(raw_texts):
            continue

        if max_num_images_per_document:
            num_images = sum(image is not None for image in raw_images)
            if num_images > max_num_images_per_document:
                continue

        any_blurred = False
        if is_raw_images and blur_threshold > 0.0:
            for image in raw_images:
                if image is not None:
                    image = _convert_to_rgb(image)
                    any_blurred = any_blurred or is_blurred_fn(image, threshold=blur_threshold)
                    if any_blurred:
                        break
        if any_blurred:
            continue

        # Work on a copy: the splitting below rewrites entries and must not mutate the caller's sample.
        raw_texts = list(raw_texts)
        inds_of_texts_to_split = [
            i
            for i, text in enumerate(raw_texts)
            if text is not None and isinstance(text, str) and "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" in text
        ]
        if inds_of_texts_to_split:
            splitted_raw_images, splitted_raw_texts = [], []
            previous_i = 0
            for i in inds_of_texts_to_split:
                splitting = raw_texts[i].split("END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED")
                part1, part2 = splitting[0], splitting[-1]

                sub_doc_images = raw_images[previous_i:i] + [None]
                sub_doc_texts = raw_texts[previous_i:i] + [part1.strip()]
                # Sub-documents without any image (e.g. raw_images[previous_i:i] are all None) are dropped,
                # but the split position below must still advance; otherwise their text (including the
                # un-replaced marker) would leak into the next sub-document.
                if any(sub_doc_images):
                    splitted_raw_images.append(sub_doc_images)
                    splitted_raw_texts.append(sub_doc_texts)

                if part2.strip() == "":
                    previous_i = i + 1
                else:
                    # The text after the marker starts the next sub-document.
                    raw_texts[i] = part2.strip()
                    previous_i = i

            if previous_i < len(raw_images) and any(raw_images[previous_i:]):
                splitted_raw_images.append(raw_images[previous_i:])
                splitted_raw_texts.append(raw_texts[previous_i:])
        else:
            splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts]

        # Sanity check
        if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
            raise ValueError(
                "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                " Something core went wrong during the splitting and needs to be fixed."
            )

        for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
            images, web_text = [], ""
            # Reset per sub-document: otherwise a previous document ending with an image would
            # prepend a stray FAKE_TOKEN_AROUND_IMAGE_V2 to the first text of this one.
            last_was_image = False
            for image, text in zip(s_r_ims, s_r_txts):
                if text is None and image is None:
                    continue

                if image is not None:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                    if is_raw_images:
                        images.append(image_transform(image))
                    else:
                        images.append(torch.tensor(image))
                    last_was_image = True
                elif text is not None:
                    if last_was_image:
                        # Close the preceding image with its fake token.
                        web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                        last_was_image = False
                    else:
                        web_text += f" {text}" if web_text != "" else text

            if last_was_image:
                web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"

            web_text = web_text.strip(" ")

            # This is mostly a sanity check. Cases like that should not happen at that point.
            if web_text == "" or len(images) == 0:
                continue

            images = torch.stack(images)
            all_images.append(images)

            web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
            if add_end_of_doc_token:
                web_text_ids += [tokenizer.eos_token_id]

            if add_begin_of_doc_token:
                web_text_ids = [tokenizer.bos_token_id] + web_text_ids
            all_texts.append(web_text_ids)

    output_input_ids = []
    output_images = []
    output_attention_masks = []
    output_num_images = []
    output_num_text_tokens = []

    input_ids_to_pack = []
    images_to_pack = []
    for images, text in zip(all_images, all_texts):
        # We save all the documents which are shorter than the max_seq_len to pack them together.
        if len(text) <= max_seq_len:
            if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK:  # Filter out extremely short sequences
                continue
            input_ids_to_pack.extend(text)
            images_to_pack.extend(images)
        else:
            # Computing the bonus scores for tokens near images to skew the sampling towards them.
            # The main idea is to give a bonus to tokens that are closely before an image token, so that these
            # tokens have more chance to be sampled as start indices.
            # Bonuses are computed for each image, which means a given token can receive bonuses from multiple
            # images if it closely precedes multiple images.
            # We sum all the bonuses and L1-normalize along the seq_len axis to get a probability distribution.
            # Each token starts with a regular score of 1, which corresponds to the uniform distribution over the
            # sequence when there are no bonuses added.
            # The remaining question is which preceding tokens we distribute bonuses to.
            # For the sampled sub-sequence to be valid (i.e. to contain an image), the start index can only be
            # among [image_idx - max_seq_len + 1, image_idx].
            # Split that interval in 3 parts: left, middle and right (in increasing order).
            # If we give bonuses to the tokens just before the image (right part), we favor p_next=0 because only
            # the tokens after the image have an image to attend to: images tend to be at the beginning of the
            # sampled sub-sequence.
            # If we give bonuses very far before the image (left part), we favor p_next=1 because only the tokens
            # before the image have an image to attend to: images tend to be at the end of the sampled sub-sequence.
            # Giving bonuses to the middle part makes images tend to be in the middle of the sampled sequence.
            # Ultimately, we don't want to skew the distribution fed to the model in any of these ways and want all
            # cases represented equally, so the easiest is to give a bonus to every valid start index of the image.
            all_scores = np.ones(len(text), dtype=np.int64)
            for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
                # Exactly the start indices whose max_seq_len-long window contains this image token.
                all_scores[max(0, img_token_idx - max_seq_len + 1) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
            # Optionally, the bonuses could be clipped to avoid too high values (i.e. outlier documents):
            # all_scores = np.clip(all_scores, a_min=1, a_max=3 * _IMAGE_BONUS_VALUE * max_num_images + 1)
            all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]

            # The number of samples is proportional to the length of the text and inversely proportional to the
            # maximum sequence length
            max_num_samples_for_curr_document = len(text) // max_seq_len
            # Set "reproducible randomness" by creating an np.default_rng seeded by
            # (main seed, epoch, rank_idx, worker_idx, mapped_batch_index, text len)
            choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
                range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK),  # shorter sub-sequences are reserved for packing
                min(
                    len(text) - max_seq_len, 2 * max_num_samples_per_document
                ),  # Sampling more than necessary and then breaking out of the for loop once we have enough samples
                p=all_scores / np.linalg.norm(all_scores, ord=1),
                replace=False,
            )

            nb_effective_sequences_out_of_sampling = 0
            for start_index in choices:
                image_start_index = text[:start_index].count(image_token_id)
                text_sub_sequence = text[start_index : start_index + max_seq_len]

                image_count = text_sub_sequence.count(image_token_id)
                if image_count == 0:
                    # Skip if there are no images in the sequence
                    continue

                if len(text_sub_sequence) < max_seq_len:
                    # If the sub-sequence is shorter than max_seq_len, we reserve it for packing.
                    # It necessarily means that the sub-sequence was sampled towards the end of the document,
                    # which implies that we only need the `image_start_index` and not the `image_end_index`.
                    if text_sub_sequence.count(image_token_id) != len(images[image_start_index:]):
                        # A safeguard for this
                        logger.warning(
                            "Skipping this sample because of mismatch in actual number of images and "
                            "the '<image>' tokens in the text"
                        )
                        continue
                    input_ids_to_pack.extend(text_sub_sequence)
                    images_to_pack.extend(images[image_start_index:])
                    continue

                current_images = images[image_start_index : image_start_index + min(max_num_images, image_count)]
                if len(current_images) != min(max_num_images, image_count):
                    # A safeguard for something off about this document, maybe an `<image>` tag that
                    # was there from before or some issue in parsing the image?
                    logger.warning(
                        "Skipping this sample because of mismatch in actual number of images and "
                        "the '<image>' tokens in the text"
                    )
                    break
                padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
                padded_image_tensor[: min(max_num_images, image_count)] = current_images
                output_images.append(padded_image_tensor)
                output_num_images.append(min(max_num_images, image_count))

                output_input_ids.append(torch.tensor(text_sub_sequence))
                output_num_text_tokens.append(len(text_sub_sequence))

                attention_mask = torch.ones((max_seq_len,), dtype=torch.long)
                output_attention_masks.append(attention_mask)

                nb_effective_sequences_out_of_sampling += 1
                if nb_effective_sequences_out_of_sampling >= min(
                    max_num_samples_for_curr_document, max_num_samples_per_document
                ):
                    # We got all the samples we need for this document, so breaking out
                    break

    # Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack`
    if input_ids_to_pack:
        image_counter = 0
        for i in range(0, len(input_ids_to_pack), max_seq_len):
            current_input_ids = input_ids_to_pack[i : i + max_seq_len]
            unpadded_seq_len = len(current_input_ids)
            num_images = current_input_ids.count(image_token_id)
            if num_images == 0:
                continue
            # Images are consumed in the same order as their tokens appear in the packed stream.
            current_images = images_to_pack[image_counter : image_counter + num_images]
            image_counter += num_images
            if unpadded_seq_len < max_seq_len:
                padded_input_ids = [tokenizer.pad_token_id] * max_seq_len
                padded_input_ids[:unpadded_seq_len] = current_input_ids
                current_input_ids = padded_input_ids
            elif unpadded_seq_len > max_seq_len:
                # This case has no purpose other than safeguard
                continue
            try:
                current_images = torch.stack(current_images)[:max_num_images]
            except Exception as e:
                # Best effort: skip chunks whose images cannot be stacked (e.g. none left, shape mismatch).
                logger.warning("Skipping a packed sequence because its images could not be stacked: %s", e)
                continue
            padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
            padded_image_tensor[: current_images.size(0)] = current_images
            attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
            attention_mask[:unpadded_seq_len] = 1

            output_images.append(padded_image_tensor)
            output_input_ids.append(torch.tensor(current_input_ids))
            output_num_text_tokens.append(unpadded_seq_len)
            output_num_images.append(min(max_num_images, num_images))

            output_attention_masks.append(attention_mask)

    if len(output_images) == 0 or len(output_input_ids) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "next_image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
        output_input_ids, tokenizer
    )
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    result = {
        "input_ids": output_input_ids,
        "attention_mask": output_attention_masks,
        "image_attention_mask": image_attention_mask,
        "next_image_attention_mask": next_image_attention_mask,
        "num_images": torch.tensor(output_num_images),
        "num_text_tokens": torch.tensor(output_num_text_tokens),
    }
    if is_raw_images:
        result["pixel_values"] = output_images
    else:
        result["image_embeddings"] = output_images

    return result
def split_and_pad_pmd(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    prob_image_at_end=0.5,  # If 1, the <image> token is always added at the end of the text
    # If set to -1, all padding will be tolerated. If set to 0, no padding will be tolerated.
    padding_tolerance=-1,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
):
    """
    Pack image-text pairs (one image per text, PMD-style) into sequences of `max_seq_len` tokens.

    Each pair becomes `<fake><image><fake>` followed by the text, or the text followed by
    `<fake><image><fake>`; which of the two is decided once per call with probability
    `prob_image_at_end`, seeded by `prefix_seed`. Pairs are sorted by token length and packed
    greedily: each sequence starts with the longest remaining pair (truncated to `max_seq_len`)
    and is then filled with the longest remaining pair that still fits.

    Args:
        sample (Dict): Reads `text` and either `image_embedding` or `image` (parallel lists, one entry per pair).
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Length of every returned sequence.
        image_transform (Callable): Transform applied to raw images.
        max_num_images (int): Images beyond this number in a sequence are dropped; fewer are zero-padded.
        prefix_seed: Prefix seed sequence for "reproducible randomness".
        is_blurred_fn (Callable, optional): `fn(image, threshold=...) -> bool` used to drop pairs with a blurry
            raw image. Defaults to `fft_blur_detection`.
        blur_threshold (float, optional): Threshold passed to `is_blurred_fn`; `0.0` disables blur filtering.
        prob_image_at_end (float, optional): Probability that images are placed after the text.
        padding_tolerance (int, optional): Maximum number of padding tokens tolerated in a sequence. When more
            would be needed and no remaining pair fits, randomly chosen pairs (possibly already used) are
            appended and the sequence is truncated to `max_seq_len`. `-1` tolerates any amount of padding.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to every pair.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to every pair.

    Returns:
        Dict[str, torch.Tensor]: `input_ids`, `attention_mask`, `image_attention_mask`, `num_images`,
        `num_text_tokens`, and `pixel_values` (raw images) or `image_embeddings`. When images are placed at the
        end of the text, `image_attention_mask` holds the *next*-image mask so that text tokens attend to the
        image that follows them. All tensors are empty when no sequence could be built.

    Raises:
        ValueError: If the sample has neither `image_embedding` nor `image`.
    """
    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    text_batch = sample["text"]
    image_batch = sample.get("image_embedding", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("image", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embedding or image must be present in the sample")

    filtered_image_batch = []
    filtered_input_ids = []

    # Define for the current PMD batch whether the images will be at the start or at the end.
    rng = np.random.default_rng(seed=list(prefix_seed))
    is_image_at_end = False

    # rng.random is between 0 and 1, so if prob_image_at_end is 1, random value will
    # always be less than `prob_image_at_end` and `is_image_at_end` will always be True.
    # This means that images will always be at the end of the text.
    if rng.random() < prob_image_at_end:
        is_image_at_end = True

    for image, text in zip(image_batch, text_batch):
        if text is None or image is None:
            continue

        # Same blur filtering as `split_pack_and_pad`: only when enabled, on an RGB copy of the image.
        if is_raw_images and blur_threshold > 0.0 and is_blurred_fn(_convert_to_rgb(image), threshold=blur_threshold):
            continue

        sample_text = f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}{FAKE_TOKEN_AROUND_IMAGE_V2}"

        # Remove trailing and leading whitespaces, including newlines and tabs
        text = text.strip()

        if is_image_at_end:
            sample_text = f"{text}{sample_text}"
        else:
            sample_text = f"{sample_text}{text}"

        sample_input_ids = tokenizer.encode(sample_text, add_special_tokens=False)
        if add_end_of_doc_token:
            sample_input_ids += [tokenizer.eos_token_id]

        if add_begin_of_doc_token:
            sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids

        filtered_image_batch.append(image)
        filtered_input_ids.append(sample_input_ids)

    # Sort by length of text and save same-length elements in a mapping so we
    # can retrieve candidates later. Guarded because unpacking `zip(*[])` fails when every
    # pair has been filtered out; the loop below then returns the empty result.
    if filtered_input_ids:
        filtered_image_batch, filtered_input_ids = zip(
            *sorted(zip(filtered_image_batch, filtered_input_ids), key=lambda x: len(x[1]))
        )

    # Keys are inserted in increasing order, so `mapping_by_len.keys()` stays sorted for bisect.
    mapping_by_len = OrderedDict()
    for i, sample_input_ids in enumerate(filtered_input_ids):
        if len(sample_input_ids) not in mapping_by_len:
            mapping_by_len[len(sample_input_ids)] = []
        mapping_by_len[len(sample_input_ids)].append((filtered_image_batch[i], sample_input_ids))

    all_images = []
    all_texts = []
    all_attention_masks = []
    all_num_images = []
    all_num_text_tokens = []
    current_text = []
    current_images = []
    while True:
        current_lens = list(mapping_by_len.keys())
        if len(current_text) > 0:
            # Now we try to do a binary search to find the biggest sequence that
            # we can fit into the current sequence.
            # This will eventually use up bigger sequences faster which is good
            # and leave smaller sequences to pack with each other later.
            diff = max_seq_len - len(current_text)
            if len(current_lens) == 0:
                possible_index = -1
            else:
                possible_index = bisect_left(current_lens, diff)
                if possible_index == len(current_lens) or current_lens[possible_index] != diff:
                    possible_index -= 1

            if possible_index >= 0:
                best_possible_length = current_lens[possible_index]
                image, sample_input_ids = mapping_by_len[best_possible_length].pop(0)

                # If we have used up all the samples of a certain length, remove
                # that length from the mapping.
                if len(mapping_by_len[best_possible_length]) == 0:
                    del mapping_by_len[best_possible_length]
                current_text.extend(sample_input_ids)
                if is_raw_images:
                    current_images.append(image_transform(image))
                else:
                    current_images.append(torch.tensor(image))
            elif diff > padding_tolerance and padding_tolerance != -1:
                # If we are here, it means that we still have padding left
                # and we have exhausted our current unique options that will allow us to
                # fill this sequence completely.
                # So, we will try to fill the sequence with whatever we get from the unchanged
                # copy of all sequences.
                while diff > padding_tolerance:
                    # Find a random sequence to fit.
                    # Why do we need to add more stuff to the prefix seed?
                    # prefix_seed will be the same in the same batch which means that it might sample
                    # the same thing again and again if there are multiple cases of padding in the
                    # same batch, so we need to make this part as random as possible.
                    # Every seed entry must be a non-negative int: use the total number of images
                    # packed so far (the list itself is not a valid seed entry).
                    rng = np.random.default_rng(
                        list(prefix_seed)
                        + [
                            diff,
                            len(current_text),
                            len(all_texts),
                            sum(all_num_images),
                        ]
                    )
                    choice = rng.choice(range(len(filtered_input_ids)))
                    image, sample_input_ids = filtered_image_batch[choice], filtered_input_ids[choice]
                    current_text.extend(sample_input_ids)
                    if is_raw_images:
                        current_images.append(image_transform(image))
                    else:
                        current_images.append(torch.tensor(image))
                    diff = max_seq_len - len(current_text)
                # In the next top-level while loop iteration, this should go into the else
                # clause which should also handle the sequences longer than max_seq_len
            else:
                # Nothing else fits: flush the current sequence (truncated to max_seq_len, padded otherwise).
                current_images = torch.stack(current_images)
                padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
                padded_image_tensor[: current_images.size(0)] = current_images[
                    : min(max_num_images, current_images.size(0))
                ]
                all_num_images.append(min(max_num_images, current_images.size(0)))
                all_images.append(padded_image_tensor)

                padded_input_ids = torch.full((max_seq_len,), tokenizer.pad_token_id)
                current_max_len = min(max_seq_len, len(current_text))
                padded_input_ids[:current_max_len] = torch.tensor(current_text)[:current_max_len]
                all_num_text_tokens.append(current_max_len)
                all_texts.append(padded_input_ids)

                attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
                attention_mask[: len(current_text)] = 1
                all_attention_masks.append(attention_mask)

                # Make sure to reset the current text and images.
                current_images = []
                current_text = []
                if len(current_lens) == 0:
                    break
        else:
            # A case where we might not have any samples left over after the initial filtering step.
            if len(current_lens) == 0:
                break
            # Start a new sequence with the longest remaining pair.
            image, sample_input_ids = mapping_by_len[current_lens[-1]].pop(0)
            if len(mapping_by_len[current_lens[-1]]) == 0:
                del mapping_by_len[current_lens[-1]]
            current_text = sample_input_ids[:max_seq_len]
            if is_raw_images:
                current_images = [image_transform(image)]
            else:
                current_images = [torch.tensor(image)]

    if len(all_images) == 0 or len(all_texts) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    all_texts = torch.stack(all_texts)
    all_images = torch.stack(all_images)
    all_attention_masks = torch.stack(all_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(all_texts, tokenizer)
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    output = {
        "input_ids": all_texts,
        "attention_mask": all_attention_masks,
        "image_attention_mask": image_attention_mask,
        "num_images": torch.tensor(all_num_images),
        "num_text_tokens": torch.tensor(all_num_text_tokens),
    }
    if is_raw_images:
        output["pixel_values"] = all_images
    else:
        output["image_embeddings"] = all_images

    if is_image_at_end:
        # Set the correct attention mask based on whether the image is at the start
        # or not. When it is at the end, we need the next image attention mask.
        output["image_attention_mask"] = next_image_attention_mask

    return output
# Copied from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py
def random_spans_helper(
    inputs_length,
    noise_density,
    mean_noise_span_length,
    extra_tokens_per_span_inputs,
    extra_tokens_per_span_targets,
    verbose=False,
):
    """Training parameters to avoid padding with random_spans_noise_mask.

    When training a model with random_spans_noise_mask, we would like to set the
    other training hyperparameters in a way that avoids padding. This function
    helps us compute these hyperparameters.

    We assume that each noise span in the input is replaced by
    extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
    targets is replaced by extra_tokens_per_span_targets sentinel tokens.

    This function tells us the required number of tokens in the raw example (for
    split_tokens()) as well as the length of the encoded targets.

    Note that this function assumes the inputs and targets will have EOS appended
    and includes that in the reported length.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
        extra_tokens_per_span_inputs: an integer (only 1 is supported)
        extra_tokens_per_span_targets: an integer (only 1 is supported)
        verbose: a bool indicating whether to log sequence lengths

    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence

    Raises:
        NotImplementedError: if extra_tokens_per_span_inputs or extra_tokens_per_span_targets is not 1.
    """
    if extra_tokens_per_span_inputs != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_inputs != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )
    if extra_tokens_per_span_targets != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_targets != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
            num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
        )

    # Grow the raw length as long as the encoded inputs still fit in `inputs_length`.
    tokens_length = inputs_length - 1
    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    if verbose:
        # Use the module logger (like the rest of this file) rather than the root logger.
        logger.info(
            "tokens_length=%s inputs_length=%s targets_length=%s noise_density=%s mean_noise_span_length=%s ",
            tokens_length,
            inputs_length,
            targets_length,
            noise_density,
            mean_noise_span_length,
        )
    return tokens_length, targets_length