Spaces:
Build error
Build error
import logging | |
from bisect import bisect_left | |
from collections import OrderedDict | |
import cv2 | |
import numpy as np | |
import torch | |
from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, _convert_to_rgb | |
logger = logging.getLogger(__name__)

# Hyper-parameters
# Score bonus given to the candidate start indexes that precede an image token.
_IMAGE_BONUS_VALUE = 2
# Minimum length (in tokens) for a document to be packed; shorter documents are dropped.
_MIN_LENGTH_DOCUMENTS_TO_PACK = 5
def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
    """Convert an incremental image-index mask into a one-hot (binary) attention mask.

    Each entry of ``incremental_mask`` is the index of the image a token attends to, or ``-1``
    when it attends to no image. Example: ``[-1, 0, 1]`` => ``[[0, 0], [1, 0], [0, 1]]``.

    Args:
        incremental_mask (torch.LongTensor): Image indexes, ``-1`` meaning "no image".
            It is not modified.
        num_classes (int, optional): Size of the one-hot dimension. Indexes ``>= num_classes``
            (images beyond the allowed number) are treated like ``-1``, so those tokens attend
            to nothing. With ``-1`` (default) the size is inferred by
            ``torch.nn.functional.one_hot``.

    Returns:
        torch.LongTensor: ``incremental_mask.shape + (num_classes,)``, all-zero rows for tokens
        that attend to no image.
    """
    # Work on a copy: the original overwrote the caller's tensor in place.
    incremental_mask = incremental_mask.clone()
    if num_classes != -1:
        incremental_mask[incremental_mask >= num_classes] = -1
    negatives = incremental_mask == -1
    # one_hot rejects negative indexes: map them to 0, then blank their rows afterwards.
    incremental_mask[negatives] = 0
    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
    attn_mask[negatives, :] = 0
    return attn_mask
def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
    """Compute, for every token of a packed batch, the index of the image it attends to.

    Images are numbered per row in order of appearance (0, 1, ...), counting every
    ``IMAGE_TOKEN`` of the row regardless of document boundaries (EOS tokens).

    Two masks with the same shape and dtype as ``input_ids`` are returned, ``-1`` meaning
    "no image":

    * ``image_attention_mask``: index of the closest image token at or before each token.
      Tokens after an EOS token get ``-1`` until the next image token; the EOS token itself
      keeps the index of the image before it.
    * ``next_image_attention_mask``: index of the closest image token at or after each token.
      An EOS token and the tokens before it get ``-1`` (back to the previous image token).

    Args:
        input_ids (torch.Tensor): 2-D tensor of token ids, one packed sequence per row.
        tokenizer: Used for ``convert_tokens_to_ids(IMAGE_TOKEN)`` and ``eos_token_id``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(image_attention_mask, next_image_attention_mask)``.
    """
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id
    image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)

    # Scan plain Python ints: reading/comparing tensor elements one at a time allocates a 0-d
    # tensor per access (and syncs the device per comparison when input_ids lives on GPU).
    for row_idx, row in enumerate(input_ids.tolist()):
        # Forward pass: index of the latest image seen so far.
        previous_marks = []
        image_idx = -1
        after_eod = False
        for token_id in row:
            if token_id == image_token_id:
                image_idx += 1
                after_eod = False
                previous_marks.append(image_idx)
            else:
                previous_marks.append(-1 if after_eod else image_idx)
                if token_id == eod_token_id:
                    after_eod = True

        # Backward pass: images are first counted from the end of the row...
        next_marks = [-1] * len(row)
        image_idx = -1
        before_eod = False
        for pos in range(len(row) - 1, -1, -1):
            token_id = row[pos]
            if token_id == image_token_id:
                image_idx += 1
                before_eod = False
                next_marks[pos] = image_idx
            else:
                if token_id == eod_token_id:
                    before_eod = True
                next_marks[pos] = -1 if before_eod else image_idx
        # ...then renumbered from the start, to match the forward numbering.
        last_image_idx = image_idx
        next_marks = [-1 if mark == -1 else last_image_idx - mark for mark in next_marks]

        image_attention_mask[row_idx] = torch.tensor(
            previous_marks, dtype=image_attention_mask.dtype, device=image_attention_mask.device
        )
        next_image_attention_mask[row_idx] = torch.tensor(
            next_marks, dtype=next_image_attention_mask.dtype, device=next_image_attention_mask.device
        )
    return image_attention_mask, next_image_attention_mask
def laplacian_blur_detection(image, threshold=0.0):
    """Tell whether an RGB image is blurry using the variance of its Laplacian.

    The variance of the Laplacian of the grayscale image is used as a focus measure; the image
    is reported as blurry when that measure is below ``threshold``.

    Args:
        image: Anything ``np.array`` can turn into an ``(H, W, 3)`` RGB array.
        threshold (float): Focus threshold. ``0.0`` disables the check.

    Returns:
        bool: ``True`` if blurry. Images that are not 3-channel (e.g. grayscale) are never
        reported as blurry.
    """
    if threshold == 0.0:
        return False
    pixels = np.array(image)
    is_rgb = pixels.ndim == 3 and pixels.shape[2] == 3
    if not is_rgb:
        # Don't remove grayscale images
        return False
    gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    focus_measure = cv2.Laplacian(gray, cv2.CV_64F).var()
    return focus_measure < threshold
def fft_blur_detection(image, size=50, threshold=0.0):
    """Tell whether an image is blurry from the high-frequency content of its spectrum.

    The image is converted to a 2-D grayscale array. The low frequencies within ``size`` of
    the centre of its shifted 2-D FFT are zeroed out, and the mean of ``20 * log|recon|`` over
    the inverse transform is compared to ``threshold``.

    Args:
        image: 2-D grayscale array, or an ``(H, W, C)`` array / array-convertible image (e.g. a
            PIL image). With 3+ channels the first three are combined with the
            (0.299, 0.587, 0.114) RGB->gray weights (as OpenCV does); fewer channels are averaged.
        size (int): Half-width of the square of low frequencies removed around the centre.
            The square is clipped to the image bounds.
        threshold (float): Images whose mean log-magnitude is below it are blurry.
            ``0.0`` disables the check.

    Returns:
        bool: ``True`` if the image is considered blurry.

    Raises:
        ValueError: If the image is neither 2-D nor 3-D.
    """
    if threshold == 0.0:
        return False
    pixels = np.asarray(image, dtype=np.float64)
    if pixels.ndim == 3:
        # Callers pass RGB (PIL) images: the original `(h, w) = image.shape` crashed on them.
        if pixels.shape[2] >= 3:
            pixels = pixels[..., :3] @ np.array([0.299, 0.587, 0.114])
        else:
            pixels = pixels.mean(axis=2)
    if pixels.ndim != 2:
        raise ValueError(f"Expected a 2-D or 3-D image, got an array of shape {pixels.shape}")

    (h, w) = pixels.shape
    (c_x, c_y) = (int(w / 2.0), int(h / 2.0))
    spectrum = np.fft.fftshift(np.fft.fft2(pixels))
    # Clamp the start at 0: a negative start index would wrap around and zero the wrong region.
    spectrum[max(0, c_y - size) : c_y + size, max(0, c_x - size) : c_x + size] = 0
    recon = np.fft.ifft2(np.fft.ifftshift(spectrum))
    with np.errstate(divide="ignore"):
        # log(0) -> -inf (nothing left after filtering), which is then reported as blurry.
        magnitude = 20 * np.log(np.abs(recon))
    return bool(np.mean(magnitude) < threshold)
_END_OF_DOCUMENT_MARKER = "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED"


def _split_on_end_of_document_marker(raw_images, raw_texts):
    """Split one interleaved document into sub-documents on the end-of-document marker.

    A text containing the marker closes the current sub-document with the (stripped) text
    before the marker, followed by a ``None`` image slot; the text after the marker, if any,
    opens the next sub-document. Sub-documents without any image are dropped. The inputs are
    not modified.

    Returns:
        Tuple[List[list], List[list]]: Sub-document image lists and text lists, aligned
        element-wise.
    """
    raw_images = list(raw_images)
    raw_texts = list(raw_texts)  # Copy: entries are rewritten below.
    marker_indices = [
        i for i, text in enumerate(raw_texts) if isinstance(text, str) and _END_OF_DOCUMENT_MARKER in text
    ]
    if not marker_indices:
        return [raw_images], [raw_texts]

    splitted_raw_images, splitted_raw_texts = [], []
    previous_i = 0
    for i in marker_indices:
        splitting = raw_texts[i].split(_END_OF_DOCUMENT_MARKER)
        part1, part2 = splitting[0].strip(), splitting[-1].strip()
        sub_doc_images = raw_images[previous_i:i] + [None]
        sub_doc_texts = raw_texts[previous_i:i] + [part1]
        # Advance past the marker before deciding whether to keep this sub-document: a dropped
        # image-less sub-document must not leak its texts (and the raw marker) into the next one.
        if part2 == "":
            previous_i = i + 1
        else:
            raw_texts[i] = part2
            previous_i = i
        if any(sub_doc_images):
            splitted_raw_images.append(sub_doc_images)
            splitted_raw_texts.append(sub_doc_texts)
    if previous_i < len(raw_images) and any(raw_images[previous_i:]):
        splitted_raw_images.append(raw_images[previous_i:])
        splitted_raw_texts.append(raw_texts[previous_i:])
    return splitted_raw_images, splitted_raw_texts


def split_pack_and_pad(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    max_num_samples_per_document=10,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
    max_num_images_per_document=None,
):
    """
    Return a batch of samples in the format expected by the model, built from a batch of
    interleaved image-text documents.

    Documents are first split into sub-documents on the `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`
    marker and tokenized, with images replaced by `IMAGE_TOKEN` surrounded by
    `FAKE_TOKEN_AROUND_IMAGE_V2`.

    From each document longer than `max_seq_len` tokens we sample at most
    `min(max_num_samples_per_document, len(document) // max_seq_len)` sub-sequences of exactly
    `max_seq_len` tokens that contain at least one image. Start indexes are skewed towards
    sub-sequences that contain images. Each such sample keeps at most `max_num_images` images,
    zero-padded up to `max_num_images`. Shorter documents, and sampled sub-sequences that run
    past the end of their document, are concatenated and packed into `max_seq_len` chunks
    padded with `tokenizer.pad_token_id`.

    Args:
        sample (Dict): Batch with a "texts" key and either an "image_embeddings" or an "images"
            key. Each holds one list per document; positions pair an image (or None) with a
            text (or None).
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Length of every returned `input_ids` row.
        image_transform (Callable): Transform applied to raw images (not to embeddings).
        max_num_images (int): Number of image slots per returned sample. Missing images are
            zero-padded; extra images are dropped.
        max_num_samples_per_document (int, optional): Maximum number of sub-sequences sampled
            from one long document. Defaults to 10.
        prefix_seed: Prefix seed sequence for "reproducible randomness" in calls to
            `np.random.default_rng(...).choice`; the document length is appended to it.
        is_blurred_fn (Callable, optional): `fn(image, threshold=...)` used to drop documents
            containing a blurry raw image. Defaults to `fft_blur_detection`.
        blur_threshold (float, optional): Threshold passed to `is_blurred_fn`. `0.0` (default)
            disables blur filtering.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to each document.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to each document.
        max_num_images_per_document (int, optional): Documents with more (non-None) images are
            skipped. Falsy values disable this filter.

    Returns:
        Dict[str, torch.Tensor]: `input_ids` and `attention_mask` of shape (N, max_seq_len),
        `image_attention_mask` and `next_image_attention_mask` of shape
        (N, max_seq_len, max_num_images), `num_images` and `num_text_tokens` of shape (N,), and
        `pixel_values` (raw images) or `image_embeddings` of shape (N, max_num_images, ...).
        When no sample could be built, every value is an empty tensor.

    Raises:
        ValueError: If the sample has neither images nor image embeddings, or if splitting a
            document produced misaligned image and text lists.
    """
    text_batch = sample["texts"]
    image_batch = sample.get("image_embeddings", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("images", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embeddings or images must be present in the sample")

    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    # 1. Turn every (sub-)document into a stacked image tensor and a list of token ids.
    all_images = []
    all_texts = []
    for raw_images, raw_texts in zip(image_batch, text_batch):
        # Skip documents without at least one image and one text
        if not any(raw_images) or not any(raw_texts):
            continue
        if max_num_images_per_document:
            num_images = sum(1 for image in raw_images if image is not None)
            if num_images > max_num_images_per_document:
                continue
        if is_raw_images and blur_threshold > 0.0:
            if any(
                is_blurred_fn(_convert_to_rgb(image), threshold=blur_threshold)
                for image in raw_images
                if image is not None
            ):
                continue

        splitted_raw_images, splitted_raw_texts = _split_on_end_of_document_marker(raw_images, raw_texts)
        # Sanity check
        if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
            raise ValueError(
                "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                " Something core went wrong during the splitting and needs to be fixed."
            )

        for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
            images, web_text = [], ""
            # Reset for every sub-document; otherwise a document ending with an image would make
            # the next one start with a stray FAKE_TOKEN_AROUND_IMAGE_V2.
            last_was_image = False
            for image, text in zip(s_r_ims, s_r_txts):
                if text is None and image is None:
                    continue
                if image is not None:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                    if is_raw_images:
                        images.append(image_transform(image))
                    else:
                        images.append(torch.tensor(image))
                    last_was_image = True
                elif last_was_image:
                    # Close the preceding image with a fake token before the text.
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                    last_was_image = False
                else:
                    web_text += f" {text}" if web_text != "" else text
            if last_was_image:
                web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"
            web_text = web_text.strip(" ")

            # This is mostly a sanity check. Cases like that should not happen at that point.
            if web_text == "" or len(images) == 0:
                continue

            all_images.append(torch.stack(images))
            web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
            if add_end_of_doc_token:
                web_text_ids += [tokenizer.eos_token_id]
            if add_begin_of_doc_token:
                web_text_ids = [tokenizer.bos_token_id] + web_text_ids
            all_texts.append(web_text_ids)

    output_input_ids = []
    output_images = []
    output_attention_masks = []
    output_num_images = []
    output_num_text_tokens = []
    # Short documents (and short tails of long ones) are accumulated here and packed in step 3.
    input_ids_to_pack = []
    images_to_pack = []

    # 2. Sample full-length sub-sequences from the long documents.
    for images, text in zip(all_images, all_texts):
        if len(text) <= max_seq_len:
            if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK:  # Filter out extremely short sequences
                continue
            input_ids_to_pack.extend(text)
            images_to_pack.extend(images)
            continue

        # Skew the choice of start indexes towards sub-sequences that contain an image.
        # Every candidate start index has a base score of 1. A sub-sequence starting at `s`
        # contains the image token at `img_idx` iff `img_idx - max_seq_len + 1 <= s <= img_idx`,
        # so each image adds `_IMAGE_BONUS_VALUE` to exactly these start indexes (a start index
        # preceding several images collects several bonuses). Rewarding the whole window rather
        # than only part of it avoids biasing images towards the start, middle or end of the
        # sampled sub-sequences. Scores are L1-normalised into sampling probabilities; they
        # could optionally be clipped to limit the weight of outlier documents.
        all_scores = np.array([1] * len(text))
        for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
            all_scores[max(0, img_token_idx - max_seq_len + 1) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
        # The last `_MIN_LENGTH_DOCUMENTS_TO_PACK` positions are not candidate start indexes:
        # they would yield sub-sequences too short to be worth packing.
        all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]

        # The number of samples is proportional to the length of the text and inversely
        # proportional to the maximum sequence length.
        max_num_samples_for_curr_document = len(text) // max_seq_len
        # "Reproducible randomness": the rng is seeded by `prefix_seed` + the text length.
        choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
            range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK),
            # Over-sample: some candidates get skipped, we stop once we have enough samples.
            min(len(text) - max_seq_len, 2 * max_num_samples_per_document),
            p=all_scores / np.linalg.norm(all_scores, ord=1),
            replace=False,
        )

        nb_effective_sequences_out_of_sampling = 0
        for start_index in choices:
            image_start_index = text[:start_index].count(image_token_id)
            text_sub_sequence = text[start_index : start_index + max_seq_len]
            image_count = text_sub_sequence.count(image_token_id)
            if image_count == 0:
                # Skip if there are no images in the sequence
                continue

            if len(text_sub_sequence) < max_seq_len:
                # The sub-sequence runs to the end of the document, so it keeps every remaining
                # image; it is reserved for packing.
                if image_count != len(images[image_start_index:]):
                    logger.warning(
                        "Skipping this sample because of mismatch in actual number of images and "
                        "the '<image>' tokens in the text"
                    )
                    continue
                input_ids_to_pack.extend(text_sub_sequence)
                images_to_pack.extend(images[image_start_index:])
                continue

            num_images_in_sample = min(max_num_images, image_count)
            current_images = images[image_start_index : image_start_index + num_images_in_sample]
            if len(current_images) != num_images_in_sample:
                # A safeguard for something off about this document, maybe `<image>` tag that
                # by there from before or some issue in parsing the image?
                logger.warning(
                    "Skipping this sample because of mismatch in actual number of images and "
                    "the '<image>' tokens in the text"
                )
                break
            padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
            padded_image_tensor[:num_images_in_sample] = current_images
            output_images.append(padded_image_tensor)
            output_num_images.append(num_images_in_sample)
            output_input_ids.append(torch.tensor(text_sub_sequence))
            output_num_text_tokens.append(len(text_sub_sequence))
            output_attention_masks.append(torch.ones((max_seq_len,), dtype=torch.long))

            nb_effective_sequences_out_of_sampling += 1
            if nb_effective_sequences_out_of_sampling >= min(
                max_num_samples_for_curr_document, max_num_samples_per_document
            ):
                # We got all the samples we need for this document, so breaking out
                break

    # 3. Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack`.
    image_counter = 0
    for i in range(0, len(input_ids_to_pack), max_seq_len):
        current_input_ids = input_ids_to_pack[i : i + max_seq_len]
        unpadded_seq_len = len(current_input_ids)
        num_images = current_input_ids.count(image_token_id)
        if num_images == 0:
            continue
        current_images = images_to_pack[image_counter : image_counter + num_images]
        image_counter += num_images
        if unpadded_seq_len < max_seq_len:
            current_input_ids = current_input_ids + [tokenizer.pad_token_id] * (max_seq_len - unpadded_seq_len)
        try:
            current_images = torch.stack(current_images)[:max_num_images]
        except Exception as err:  # Best effort: drop this chunk, but don't hide it.
            logger.warning("Skipping a packed sequence because its images could not be stacked: %s", err)
            continue
        padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
        padded_image_tensor[: current_images.size(0)] = current_images
        attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
        attention_mask[:unpadded_seq_len] = 1
        output_images.append(padded_image_tensor)
        output_input_ids.append(torch.tensor(current_input_ids))
        output_num_text_tokens.append(unpadded_seq_len)
        output_num_images.append(min(max_num_images, num_images))
        output_attention_masks.append(attention_mask)

    if len(output_images) == 0 or len(output_input_ids) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "next_image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
        output_input_ids, tokenizer
    )
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    result = {
        "input_ids": output_input_ids,
        "attention_mask": output_attention_masks,
        "image_attention_mask": image_attention_mask,
        "next_image_attention_mask": next_image_attention_mask,
        "num_images": torch.tensor(output_num_images),
        "num_text_tokens": torch.tensor(output_num_text_tokens),
    }
    if is_raw_images:
        result["pixel_values"] = output_images
    else:
        result["image_embeddings"] = output_images
    return result
def split_and_pad_pmd(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    prob_image_at_end=0.5,  # If 1, the <image> token is always added at the end of the text
    # If set to -1, all padding will be tolerated. If set to 0, no padding will be tolerated.
    padding_tolerance=-1,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
):
    """
    Build packed and padded samples from a batch of (image, text) pairs.

    Each pair becomes `FAKE_TOKEN_AROUND_IMAGE_V2 IMAGE_TOKEN FAKE_TOKEN_AROUND_IMAGE_V2`
    followed by the stripped text. When images are placed at the end, the text comes first
    instead; this is drawn once per batch with probability `prob_image_at_end`.

    Pairs are then packed into sequences of `max_seq_len` tokens:
    - A sequence starts with the longest remaining pair (truncated to `max_seq_len`).
    - The longest remaining pair that still fits in the leftover space is appended, until none fits.
    - If `padding_tolerance != -1` and more than `padding_tolerance` positions are still empty,
      randomly drawn pairs (possibly already used) are appended; the sequence is then
      truncated to `max_seq_len`.

    Args:
        sample (Dict): Batch with a "text" key and either an "image_embedding" or an "image"
            key, aligned element-wise. Pairs where either element is None are skipped.
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Length of every returned `input_ids` row.
        image_transform (Callable): Transform applied to raw images (not to embeddings).
        max_num_images (int): Number of image slots per returned sample. Missing images are
            zero-padded; extra images are dropped.
        prefix_seed: Seed prefix for the image-placement draw and the random padding draws.
        is_blurred_fn (Callable, optional): `fn(image, threshold=...)` used to drop blurry raw
            images. Defaults to `fft_blur_detection`.
        blur_threshold (float, optional): Threshold passed to `is_blurred_fn`. `0.0` (default)
            disables blur filtering.
        prob_image_at_end (float, optional): Probability that images go at the end of the texts.
        padding_tolerance (int, optional): Number of padding positions tolerated before
            random filler pairs are appended. `-1` tolerates any amount of padding.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to each pair.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to each pair.

    Returns:
        Dict[str, torch.Tensor]: `input_ids`, `attention_mask`, `image_attention_mask` (built
        from the "next image" mask when images are at the end of the texts), `num_images`,
        `num_text_tokens`, and `pixel_values` or `image_embeddings`. When nothing survives the
        filtering, every value is an empty tensor.

    Raises:
        ValueError: If the sample has neither "image_embedding" nor "image".
    """
    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    text_batch = sample["text"]
    image_batch = sample.get("image_embedding", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("image", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embedding or image must be present in the sample")

    def _prepare_image(image):
        # Raw images go through `image_transform`; precomputed embeddings are only converted to tensors.
        return image_transform(image) if is_raw_images else torch.tensor(image)

    # Decide once for the whole PMD batch whether the images go at the start or at the end of the
    # texts. `rng.random()` lies in [0, 1), so `prob_image_at_end=1` always puts them at the end.
    rng = np.random.default_rng(seed=list(prefix_seed))
    is_image_at_end = rng.random() < prob_image_at_end

    image_text = f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}{FAKE_TOKEN_AROUND_IMAGE_V2}"
    filtered_image_batch = []
    filtered_input_ids = []
    for image, text in zip(image_batch, text_batch):
        if text is None or image is None:
            continue
        # Same blur filtering as `split_pack_and_pad`: raw images only, only when enabled, on RGB.
        if (
            is_raw_images
            and blur_threshold > 0.0
            and is_blurred_fn(_convert_to_rgb(image), threshold=blur_threshold)
        ):
            continue
        # Remove trailing and leading whitespaces, including newlines and tabs
        text = text.strip()
        sample_text = f"{text}{image_text}" if is_image_at_end else f"{image_text}{text}"
        sample_input_ids = tokenizer.encode(sample_text, add_special_tokens=False)
        if add_end_of_doc_token:
            sample_input_ids += [tokenizer.eos_token_id]
        if add_begin_of_doc_token:
            sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids
        filtered_image_batch.append(image)
        filtered_input_ids.append(sample_input_ids)

    # Sort the pairs by token length and bucket them by length (buckets stay in increasing length
    # order) so that candidates of a given length can be retrieved later. Unlike
    # `zip(*sorted(...))`, this stays valid when every pair was filtered out.
    sorted_pairs = sorted(zip(filtered_image_batch, filtered_input_ids), key=lambda pair: len(pair[1]))
    filtered_image_batch = [image for image, _ in sorted_pairs]
    filtered_input_ids = [input_ids for _, input_ids in sorted_pairs]
    mapping_by_len = OrderedDict()
    for image, sample_input_ids in sorted_pairs:
        mapping_by_len.setdefault(len(sample_input_ids), []).append((image, sample_input_ids))

    def _pop_sample_of_length(length):
        # Take the first pair of the given length; forget the length once its bucket is empty.
        bucket = mapping_by_len[length]
        popped = bucket.pop(0)
        if not bucket:
            del mapping_by_len[length]
        return popped

    all_images = []
    all_texts = []
    all_attention_masks = []
    all_num_images = []
    all_num_text_tokens = []
    current_text = []
    current_images = []
    while True:
        current_lens = list(mapping_by_len.keys())
        if len(current_text) > 0:
            # Binary search for the biggest remaining pair that fits in the space left. This uses
            # up bigger pairs first and leaves smaller ones to be packed together later.
            diff = max_seq_len - len(current_text)
            possible_index = bisect_left(current_lens, diff)
            if possible_index == len(current_lens) or current_lens[possible_index] != diff:
                possible_index -= 1  # Largest length strictly below `diff` (-1 if none).
            if possible_index >= 0:
                image, sample_input_ids = _pop_sample_of_length(current_lens[possible_index])
                current_text.extend(sample_input_ids)
                current_images.append(_prepare_image(image))
            elif diff > padding_tolerance and padding_tolerance != -1:
                # Too much padding left and no unused pair fits: fill the sequence with random
                # pairs (used or not). It may overflow `max_seq_len`; the next top-level iteration
                # finalizes it and truncates it.
                while diff > padding_tolerance:
                    # `prefix_seed` is the same for the whole batch, so the packing state is added to
                    # the seed to avoid drawing the same pair for every padded sequence.
                    # `tuple(...)` also accepts a list `prefix_seed`. NOTE(review): `all_num_images`
                    # (a list) is kept as-is so seeds stay unchanged; possibly meant its length.
                    rng = np.random.default_rng(
                        tuple(prefix_seed)
                        + (
                            diff,
                            len(current_text),
                            len(all_texts),
                            all_num_images,
                        )
                    )
                    choice = rng.choice(range(len(filtered_input_ids)))
                    current_text.extend(filtered_input_ids[choice])
                    current_images.append(_prepare_image(filtered_image_batch[choice]))
                    diff = max_seq_len - len(current_text)
            else:
                # Nothing else fits (or the remaining padding is tolerated): finalize the sequence.
                stacked_images = torch.stack(current_images)
                num_kept_images = min(max_num_images, stacked_images.size(0))
                padded_image_tensor = torch.zeros(max_num_images, *stacked_images.size()[1:])
                padded_image_tensor[:num_kept_images] = stacked_images[:num_kept_images]
                all_num_images.append(num_kept_images)
                all_images.append(padded_image_tensor)

                current_max_len = min(max_seq_len, len(current_text))
                padded_input_ids = torch.full((max_seq_len,), tokenizer.pad_token_id)
                padded_input_ids[:current_max_len] = torch.tensor(current_text)[:current_max_len]
                all_num_text_tokens.append(current_max_len)
                all_texts.append(padded_input_ids)

                attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
                attention_mask[:current_max_len] = 1
                all_attention_masks.append(attention_mask)

                # Make sure to reset the current text and images.
                current_images = []
                current_text = []
                if len(current_lens) == 0:
                    break
        else:
            # Also covers the case where no pair survived the initial filtering.
            if len(current_lens) == 0:
                break
            image, sample_input_ids = _pop_sample_of_length(current_lens[-1])
            current_text = sample_input_ids[:max_seq_len]
            current_images = [_prepare_image(image)]

    if len(all_images) == 0 or len(all_texts) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    all_texts = torch.stack(all_texts)
    all_images = torch.stack(all_images)
    all_attention_masks = torch.stack(all_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(all_texts, tokenizer)
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    output = {
        "input_ids": all_texts,
        "attention_mask": all_attention_masks,
        "image_attention_mask": image_attention_mask,
        "num_images": torch.tensor(all_num_images),
        "num_text_tokens": torch.tensor(all_num_text_tokens),
    }
    if is_raw_images:
        output["pixel_values"] = all_images
    else:
        output["image_embeddings"] = all_images

    if is_image_at_end:
        # When the image is at the end of the text, tokens attend to the next image instead.
        output["image_attention_mask"] = next_image_attention_mask

    return output
# Copied from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py
def random_spans_helper(
    inputs_length,
    noise_density,
    mean_noise_span_length,
    extra_tokens_per_span_inputs,
    extra_tokens_per_span_targets,
    verbose=False,
):
    """Training parameters to avoid padding with random_spans_noise_mask.

    When training a model with random_spans_noise_mask, we would like to set the
    other training hyperparameters in a way that avoids padding. This function
    helps us compute these hyperparameters.

    We assume that each noise span in the input is replaced by
    extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
    targets is replaced by extra_tokens_per_span_targets sentinel tokens.

    This function tells us the required number of tokens in the raw example (for
    split_tokens()) as well as the length of the encoded targets.

    Note that this function assumes the inputs and targets will have EOS appended
    and includes that in the reported length.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
        extra_tokens_per_span_inputs: an integer (only 1 is supported)
        extra_tokens_per_span_targets: an integer (only 1 is supported)
        verbose: a bool indicating whether to log sequence lengths

    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence

    Raises:
        NotImplementedError: if either extra_tokens_per_span_* argument is not 1.
    """
    if extra_tokens_per_span_inputs != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_inputs != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )
    if extra_tokens_per_span_targets != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_targets != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
            num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
        )

    # Largest raw length whose encoded inputs still fit in `inputs_length`.
    tokens_length = inputs_length - 1
    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    if verbose:
        # Same logger as the module-level `logger` (rather than the root logger).
        logging.getLogger(__name__).info(
            "tokens_length=%s inputs_length=%s targets_length=%s noise_density=%s mean_noise_span_length=%s ",
            tokens_length,
            inputs_length,
            targets_length,
            noise_density,
            mean_noise_span_length,
        )
    return tokens_length, targets_length