import logging
from bisect import bisect_left
from collections import OrderedDict

import cv2
import numpy as np
import torch

from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, _convert_to_rgb


logger = logging.getLogger(__name__)

# Hyper-parameters
_IMAGE_BONUS_VALUE = 2  # The bonus value for tokens preceding the image token
_MIN_LENGTH_DOCUMENTS_TO_PACK = (
    5  # Minimum length of documents to pack together (length is measured in number of tokens)
)


def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
    # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]

    # If any of the image indices are greater than or equal to num_classes, set them to -1.
    # Words seen after the maximum allowed number of images don't attend to anything.
    if num_classes != -1:
        incremental_mask[incremental_mask >= num_classes] = -1

    negatives = incremental_mask == -1
    incremental_mask[negatives] = 0
    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
    attn_mask[negatives, :] = 0
    return attn_mask
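
# Illustrative example for `incremental_to_binary_attention_mask` (comment-only sketch, not part
# of the original module, with hypothetical incremental indices):
#   incremental_to_binary_attention_mask(torch.tensor([[-1, 0, 1]]), num_classes=2)
# The -1 entry (no image seen yet) maps to an all-zero row, while 0 and 1 map to one-hot rows:
#   [[[0, 0], [1, 0], [0, 1]]]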


def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
    image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx, token_id in enumerate(input_ids[batch_idx]):
            if token_id == image_token_id:
                count += 1
                image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                image_attention_mask[batch_idx][idx] = count

            if seen_eod:
                image_attention_mask[batch_idx][idx] = -1

            if token_id == eod_token_id:
                seen_eod = True

    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
            token_id = input_ids[batch_idx][idx]
            if token_id == image_token_id:
                count += 1
                next_image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                next_image_attention_mask[batch_idx][idx] = count

            if token_id == eod_token_id:
                seen_eod = True

            if seen_eod:
                next_image_attention_mask[batch_idx][idx] = -1

        non_negative_indices = next_image_attention_mask[batch_idx] != -1
        next_image_attention_mask[batch_idx][non_negative_indices] -= count
        next_image_attention_mask[batch_idx][non_negative_indices] *= -1

    return image_attention_mask, next_image_attention_mask
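
# Illustrative trace for `image_attention_mask_for_packed_input_ids` (comment-only sketch with
# hypothetical token ids, assuming 3 is the <image> token id and 2 is the eos/eod token id):
#   input_ids = [[3, 5, 6, 3, 7, 2]]
#   image_attention_mask      -> [[0, 0, 0, 1, 1, 1]]    (index of the latest image seen so far)
#   next_image_attention_mask -> [[0, 1, 1, 1, -1, -1]]  (index of the next image; -1 once no image follows)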


def laplacian_blur_detection(image, threshold=0.0):
    # compute the Laplacian of the image and then return the focus
    # measure, which is simply the variance of the Laplacian
    if threshold == 0.0:
        return False

    image = np.array(image)
    if len(image.shape) == 3 and image.shape[2] == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return cv2.Laplacian(gray, cv2.CV_64F).var() < threshold
    else:
        # Don't remove grayscale images
        return False
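
# Illustrative usage (comment-only sketch, hypothetical names): a low variance of the Laplacian
# means few sharp edges, i.e. a likely blurry image.
#   laplacian_blur_detection(rgb_pil_image, threshold=100.0)  # True if variance < 100.0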


def fft_blur_detection(image, size=50, threshold=0.0):
    if threshold == 0.0:
        return False
    (h, w) = image.shape
    (cX, cY) = (int(w / 2.0), int(h / 2.0))
    fft = np.fft.fft2(image)
    fftShift = np.fft.fftshift(fft)
    fftShift[cY - size : cY + size, cX - size : cX + size] = 0
    fftShift = np.fft.ifftshift(fftShift)
    recon = np.fft.ifft2(fftShift)
    magnitude = 20 * np.log(np.abs(recon))
    mean = np.mean(magnitude)
    return mean < threshold
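
# Illustrative usage (comment-only sketch, hypothetical names): the low frequencies in a
# (2*size x 2*size) window around the spectrum centre are zeroed out, so a low mean log-magnitude
# of the reconstruction suggests little high-frequency content, i.e. blur. Note that `image` is
# expected to be a 2D (grayscale) array here, since its shape is unpacked as (h, w).
#   fft_blur_detection(np.asarray(pil_image.convert("L")), size=50, threshold=10.0)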


def split_pack_and_pad(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    max_num_samples_per_document=10,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
    max_num_images_per_document=None,
):
    """
    Return a batch of samples in the format expected by the model, which
    includes `input_ids`, `pixel_values`, `attention_mask`, `image_attention_mask`,
    and `next_image_attention_mask`. The `input_ids` are sampled from the document to
    ensure they have `max_seq_len` tokens; otherwise, the shorter documents are packed together.

    For each document, we sample a maximum of `max_num_samples_per_document` or `max_num_samples_for_curr_document`
    (where the latter is proportional to the length of the document and inversely proportional to the length of subsequences)
    `input_ids` with sequence length `max_seq_len` from the document. This means that
    each sampled sub-sequence can have a different start index. Based on the start index of the
    sampled sub-sequence, we also sample a maximum of `max_num_images` images from the document.
    If there are fewer than `max_num_images` images in the document, we pad the images with zeros.
    The start indexes are skewed towards subsequences that contain images.

    Args:
        sample (Dict): A sample object containing the document with images and text.
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Maximum sequence length of the returned text tokens.
        image_transform (Callable): Transform to be applied on the images.
        max_num_images (int): Maximum number of images to be sampled per sample. If fewer, they are padded with zeros.
        max_num_samples_per_document (int, optional): Maximum number of samples per document to be sampled. Defaults to 10.
        prefix_seed: Prefix seed sequence for "reproducible randomness" in calls to `np.random.choice`.

    Returns:
        Dict: A dictionary of tensors in the format described above, ready to be consumed by the model.
    """
    text_batch = sample["texts"]

    image_batch = sample.get("image_embeddings", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("images", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embeddings or images must be present in the sample")

    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    last_was_image = False

    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    all_images = []
    all_texts = []
    for raw_images, raw_texts in zip(image_batch, text_batch):
        # Filter out documents that don't have at least one image and one text word
        if not any(raw_images) or not any(raw_texts):
            continue

        if max_num_images_per_document:
            num_images = sum([1 if image is not None else 0 for image in raw_images])
            if num_images > max_num_images_per_document:
                continue

        any_blurred = False
        if is_raw_images and blur_threshold > 0.0:
            for image in raw_images:
                if image is not None:
                    image = _convert_to_rgb(image)
                    any_blurred = any_blurred or is_blurred_fn(image, threshold=blur_threshold)
                    if any_blurred:
                        break

        if any_blurred:
            continue

        inds_of_texts_to_split = [
            i
            for i, text in enumerate(raw_texts)
            if text is not None and isinstance(text, str) and "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" in text
        ]
        if inds_of_texts_to_split:
            splitted_raw_images, splitted_raw_texts = [], []
            previous_i = 0
            for i in inds_of_texts_to_split:
                splitting = raw_texts[i].split("END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED")
                part1, part2 = splitting[0], splitting[-1]

                sub_doc_images = raw_images[previous_i:i] + [None]
                sub_doc_texts = raw_texts[previous_i:i] + [part1.strip()]
                if not any(sub_doc_images):  # This can happen if all images in raw_images[previous_i:i] are None
                    continue

                splitted_raw_images.append(sub_doc_images)
                splitted_raw_texts.append(sub_doc_texts)

                if part2.strip() == "":
                    previous_i = i + 1
                else:
                    raw_texts[i] = part2.strip()
                    previous_i = i

            if previous_i < len(raw_images) and any(raw_images[previous_i:]):
                splitted_raw_images.append(raw_images[previous_i:])
                splitted_raw_texts.append(raw_texts[previous_i:])
        else:
            splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts]

        # Sanity check
        if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
            raise ValueError(
                "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                " Something fundamental went wrong during the splitting and needs to be fixed."
            )

        for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
            images, web_text = [], ""
            for image, text in zip(s_r_ims, s_r_txts):
                if text is None and image is None:
                    continue

                if image is not None:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                    if is_raw_images:
                        images.append(image_transform(image))
                    else:
                        images.append(torch.tensor(image))
                    last_was_image = True
                elif text is not None:
                    if last_was_image:
                        web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                        last_was_image = False
                    else:
                        web_text += f" {text}" if web_text != "" else text

            if last_was_image:
                web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"

            web_text = web_text.strip(" ")

            # This is mostly a sanity check. Cases like that should not happen at that point.
            if web_text == "" or len(images) == 0:
                continue

            images = torch.stack(images)
            all_images.append(images)

            web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
            if add_end_of_doc_token:
                web_text_ids += [tokenizer.eos_token_id]
            if add_begin_of_doc_token:
                web_text_ids = [tokenizer.bos_token_id] + web_text_ids
            all_texts.append(web_text_ids)

    output_input_ids = []
    output_images = []
    output_attention_masks = []
    output_num_images = []
    output_num_text_tokens = []

    input_ids_to_pack = []
    images_to_pack = []
    for images, text in zip(all_images, all_texts):
        # We save all the documents which are shorter than the max_seq_len to pack them together.
        if len(text) <= max_seq_len:
            if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK:  # Filter out extremely short sequences
                continue
            input_ids_to_pack.extend(text)
            images_to_pack.extend(images)
        else:
            # Computing the bonus scores for tokens near images to skew the sampling towards them.
            # The main idea is to give a bonus to tokens that are closely before an image token, so that these tokens have more chance to be sampled.
            # Bonuses are computed for each image, which means a given token can receive bonuses from multiple images if this token is closely preceding multiple images.
            # We sum all the bonuses and L1 normalize along the seq_len axis to get a probability distribution.
            # Each token starts with a regular bonus of 1, which corresponds to the uniform distribution over the sequence when there are no bonuses added.
            # Now the remaining question is which preceding tokens we distribute bonuses to.
            # We first observe that for the sampled sub-sequence to be considered valid (i.e. the sub-sequence contains an image), the start index can only be among [image_idx - max_seq_len + 1, image_idx].
            # For the sake of the explanation, let's split the [image_idx - max_seq_len + 1, image_idx] interval in 3 parts: left, middle and right (in increasing order).
            # If we give bonuses to the tokens just before the image (right part), then we are favoring p_next=0 because only the tokens after the image have an image to attend to.
            # In practice, images will tend to be at the beginning of the sampled sub-sequence.
            # If we give bonuses very far before the image (left part), then we are favoring p_next=1 because only the tokens before the image have an image to attend to.
            # In practice, images will tend to be at the end of the sampled sub-sequence.
            # To avoid favoring either p_next=0 or p_next=1, we can give bonuses to the tokens in the middle part.
            # In practice, images will tend to be in the middle of the sampled sequence.
            # Ultimately, we don't want to skew the distribution fed to the model in that way (i.e. whether images are in the beginning, middle or end of the sampled sub-sequence),
            # and have all these cases represented equally in the data. So the easiest is to distribute a bonus to all of the max_seq_len tokens preceding the image.
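            # Worked example (comment-only sketch, hypothetical numbers): with max_seq_len = 4,
            # _IMAGE_BONUS_VALUE = 2 and a single image token at position 5 of a 10-token document,
            # the scores become [1, 3, 3, 3, 3, 3, 1, 1, 1, 1] before trimming and L1 normalization:
            # the image position and the max_seq_len positions preceding it receive the bonus.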
            all_scores = np.array([1] * len(text))
            for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
                all_scores[max(0, img_token_idx - max_seq_len) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
            # all_scores = np.clip(all_scores, a_min=1, a_max=3 * _IMAGE_BONUS_VALUE * max_num_images + 1)  # We can optionally clip the bonuses to avoid having too high values (i.e. outlier documents)
            all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]

            # The number of samples is proportional to the length of the text and inversely proportional to the maximum sequence length
            max_num_samples_for_curr_document = len(text) // max_seq_len
            # Set "reproducible randomness" by creating an np.default_rng seeded by (main seed, epoch, rank_idx, worker_idx, mapped_batch_index, text len)
            choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
                range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK),  # shorter sub-sequences are reserved for packing
                min(
                    len(text) - max_seq_len, 2 * max_num_samples_per_document
                ),  # Sampling more than necessary and then breaking out of the for loop once we have enough samples
                p=all_scores / np.linalg.norm(all_scores, ord=1),
                replace=False,
            )

            nb_effective_sequences_out_of_sampling = 0
            for start_index in choices:
                image_start_index = text[:start_index].count(image_token_id)
                text_sub_sequence = text[start_index : start_index + max_seq_len]
                image_count = text_sub_sequence.count(image_token_id)
                if image_count == 0:
                    # Skip if there are no images in the sub-sequence
                    continue

                if len(text_sub_sequence) < max_seq_len:
                    # If the sub-sequence is shorter than max_seq_len, we reserve it for packing.
                    # It necessarily means that the sub-sequence was sampled towards the end of the document,
                    # which implies that we only need the `image_start_index` and not the `image_end_index`
                    if text_sub_sequence.count(image_token_id) != len(images[image_start_index:]):
                        # A safeguard for this
                        logger.warning(
                            "Skipping this sample because of mismatch in actual number of images and "
                            "the '<image>' tokens in the text"
                        )
                        continue
                    input_ids_to_pack.extend(text_sub_sequence)
                    images_to_pack.extend(images[image_start_index:])
                    continue

                current_images = images[image_start_index : image_start_index + min(max_num_images, image_count)]
                if len(current_images) != min(max_num_images, image_count):
                    # A safeguard for something off about this document, maybe an `<image>` tag
                    # left there from before or some issue in parsing the image?
                    logger.warning(
                        "Skipping this sample because of mismatch in actual number of images and "
                        "the '<image>' tokens in the text"
                    )
                    break
                padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
                padded_image_tensor[: min(max_num_images, image_count)] = current_images
                output_images.append(padded_image_tensor)
                output_num_images.append(min(max_num_images, image_count))

                output_input_ids.append(torch.tensor(text_sub_sequence))
                output_num_text_tokens.append(len(text_sub_sequence))

                attention_mask = torch.ones((max_seq_len,), dtype=torch.long)
                output_attention_masks.append(attention_mask)

                nb_effective_sequences_out_of_sampling += 1
                if nb_effective_sequences_out_of_sampling >= min(
                    max_num_samples_for_curr_document, max_num_samples_per_document
                ):
                    # We got all the samples we need for this document, so breaking out
                    break

    # Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack`
    if input_ids_to_pack:
        image_counter = 0
        for i in range(0, len(input_ids_to_pack), max_seq_len):
            current_input_ids = input_ids_to_pack[i : i + max_seq_len]
            unpadded_seq_len = len(current_input_ids)
            num_images = current_input_ids.count(image_token_id)
            if num_images == 0:
                continue
            current_images = images_to_pack[image_counter : image_counter + num_images]
            image_counter += num_images
            if unpadded_seq_len < max_seq_len:
                padded_input_ids = [tokenizer.pad_token_id] * max_seq_len
                padded_input_ids[:unpadded_seq_len] = current_input_ids
                current_input_ids = padded_input_ids
            elif unpadded_seq_len > max_seq_len:
                # This case has no purpose other than safeguard
                continue
            try:
                current_images = torch.stack(current_images)[:max_num_images]
            except Exception:
                continue
            padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
            padded_image_tensor[: current_images.size(0)] = current_images
            attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
            attention_mask[:unpadded_seq_len] = 1

            output_images.append(padded_image_tensor)
            output_input_ids.append(torch.tensor(current_input_ids))
            output_num_text_tokens.append(unpadded_seq_len)
            output_num_images.append(min(max_num_images, num_images))
            output_attention_masks.append(attention_mask)

    if len(output_images) == 0 or len(output_input_ids) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "next_image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
        output_input_ids, tokenizer
    )
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    result = {
        "input_ids": output_input_ids,
        "attention_mask": output_attention_masks,
        "image_attention_mask": image_attention_mask,
        "next_image_attention_mask": next_image_attention_mask,
        "num_images": torch.tensor(output_num_images),
        "num_text_tokens": torch.tensor(output_num_text_tokens),
    }
    if is_raw_images:
        result["pixel_values"] = output_images
    else:
        result["image_embeddings"] = output_images
    return result
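
# Illustrative usage of `split_pack_and_pad` (comment-only sketch, hypothetical names and sizes;
# in practice this is typically applied to batches of web documents, e.g. through `datasets.map`):
#   batch = split_pack_and_pad(
#       sample={"texts": texts, "images": images},
#       tokenizer=tokenizer,
#       max_seq_len=1024,
#       image_transform=image_transform,
#       max_num_images=5,
#   )
#   batch["input_ids"]             # (num_packed_samples, 1024)
#   batch["pixel_values"]          # (num_packed_samples, 5, C, H, W)
#   batch["image_attention_mask"]  # (num_packed_samples, 1024, 5), binary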


def split_and_pad_pmd(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    prob_image_at_end=0.5,  # If 1, the <image> token is always added at the end of the text
    # If set to -1, all padding will be tolerated. If set to 0, no padding will be tolerated.
    padding_tolerance=-1,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
):
    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    text_batch = sample["text"]
    image_batch = sample.get("image_embedding", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("image", None)
        is_raw_images = True

    filtered_image_batch = []
    filtered_input_ids = []

    # Define whether, for the current PMD batch, the images will be at the start or at the end.
    rng = np.random.default_rng(seed=list(prefix_seed))
    is_image_at_end = False

    # rng.random() is between 0 and 1, so if prob_image_at_end is 1, the random value will
    # always be less than `prob_image_at_end` and `is_image_at_end` will always be True.
    # This means that images will always be at the end of the text.
    if rng.random() < prob_image_at_end:
        is_image_at_end = True

    for image, text in zip(image_batch, text_batch):
        if text is None or image is None:
            continue

        if is_raw_images and is_blurred_fn(image, threshold=blur_threshold):
            continue

        sample_text = f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}{FAKE_TOKEN_AROUND_IMAGE_V2}"

        # Remove trailing and leading whitespaces, including newlines and tabs
        text = text.strip()

        if is_image_at_end:
            sample_text = f"{text}{sample_text}"
        else:
            sample_text = f"{sample_text}{text}"

        sample_input_ids = tokenizer.encode(sample_text, add_special_tokens=False)
        if add_end_of_doc_token:
            sample_input_ids += [tokenizer.eos_token_id]
        if add_begin_of_doc_token:
            sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids

        filtered_image_batch.append(image)
        filtered_input_ids.append(sample_input_ids)

    # Sort by length of text and save same-length elements in a mapping so we
    # can retrieve candidates later.
    filtered_image_batch, filtered_input_ids = zip(
        *sorted(zip(filtered_image_batch, filtered_input_ids), key=lambda x: len(x[1]))
    )

    mapping_by_len = OrderedDict()
    for i, sample_input_ids in enumerate(filtered_input_ids):
        if len(sample_input_ids) not in mapping_by_len:
            mapping_by_len[len(sample_input_ids)] = []
        mapping_by_len[len(sample_input_ids)].append((filtered_image_batch[i], sample_input_ids))
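
    # Comment-only sketch of the resulting structure (hypothetical lengths): for tokenized
    # lengths [3, 3, 5], mapping_by_len would be
    #   OrderedDict({3: [(image_a, ids_a), (image_b, ids_b)], 5: [(image_c, ids_c)]})
    # i.e. candidates grouped by length, with keys in increasing order (which the binary
    # search below relies on).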

    all_images = []
    all_texts = []
    all_attention_masks = []
    all_num_images = []
    all_num_text_tokens = []

    current_text = []
    current_images = []

    while True:
        current_lens = list(mapping_by_len.keys())
        if len(current_text) > 0:
            # Now we try to do a binary search to find the biggest sequence that
            # we can fit into the current sequence.
            # This will eventually use up bigger sequences faster, which is good,
            # and leave smaller sequences to pack with each other later.
            diff = max_seq_len - len(current_text)
            if len(current_lens) == 0:
                possible_index = -1
            else:
                possible_index = bisect_left(current_lens, diff)
                if possible_index == len(current_lens) or current_lens[possible_index] != diff:
                    possible_index -= 1

            if possible_index >= 0:
                best_possible_length = current_lens[possible_index]
                image, sample_input_ids = mapping_by_len[best_possible_length].pop(0)

                # If we have used up all the samples of a certain length, remove
                # that length from the mapping.
                if len(mapping_by_len[best_possible_length]) == 0:
                    del mapping_by_len[best_possible_length]

                current_text.extend(sample_input_ids)
                if is_raw_images:
                    current_images.append(image_transform(image))
                else:
                    current_images.append(torch.tensor(image))
            elif diff > padding_tolerance and padding_tolerance != -1:
                # If we are here, it means that we still have padding left
                # and we have exhausted our current unique options that would allow us to
                # fill this sequence completely.
                # So, we will try to fill the sequence with whatever we get from the unchanged
                # copy of all sequences.
                while diff > padding_tolerance:
                    # Find a random sequence to fit.
                    # Why do we need to add more values to the prefix seed?
                    # prefix_seed will be the same within the same batch, which means that it might sample
                    # the same thing again and again if there are multiple cases of padding in the
                    # same batch, so we need to make this part as random as possible.
                    rng = np.random.default_rng(
                        prefix_seed
                        + (
                            diff,
                            len(current_text),
                            len(all_texts),
                            all_num_images,
                        )
                    )
                    choice = rng.choice(range(len(filtered_input_ids)))
                    image, sample_input_ids = filtered_image_batch[choice], filtered_input_ids[choice]
                    current_text.extend(sample_input_ids)
                    if is_raw_images:
                        current_images.append(image_transform(image))
                    else:
                        current_images.append(torch.tensor(image))
                    diff = max_seq_len - len(current_text)
                # In the next top-level while loop iteration, this should go into the else
                # clause, which should also handle the sequences longer than max_seq_len
            else:
                current_images = torch.stack(current_images)
                padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
                padded_image_tensor[: current_images.size(0)] = current_images[
                    : min(max_num_images, current_images.size(0))
                ]
                all_num_images.append(min(max_num_images, current_images.size(0)))
                all_images.append(padded_image_tensor)

                padded_input_ids = torch.full((max_seq_len,), tokenizer.pad_token_id)
                current_max_len = min(max_seq_len, len(current_text))
                padded_input_ids[:current_max_len] = torch.tensor(current_text)[:current_max_len]
                all_num_text_tokens.append(current_max_len)
                all_texts.append(padded_input_ids)

                attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
                attention_mask[: len(current_text)] = 1
                all_attention_masks.append(attention_mask)

                # Make sure to reset the current text and images.
                current_images = []
                current_text = []

                if len(current_lens) == 0:
                    break
        else:
            # A case where we might not have any samples left over after the initial filtering step.
            if len(current_lens) == 0:
                break
            image, sample_input_ids = mapping_by_len[current_lens[-1]].pop(0)
            if len(mapping_by_len[current_lens[-1]]) == 0:
                del mapping_by_len[current_lens[-1]]
            current_text = sample_input_ids[:max_seq_len]
            if is_raw_images:
                current_images = [image_transform(image)]
            else:
                current_images = [torch.tensor(image)]

    if len(all_images) == 0 or len(all_texts) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    all_texts = torch.stack(all_texts)
    all_images = torch.stack(all_images)
    all_attention_masks = torch.stack(all_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(all_texts, tokenizer)
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    output = {
        "input_ids": all_texts,
        "attention_mask": all_attention_masks,
        "image_attention_mask": image_attention_mask,
        "num_images": torch.tensor(all_num_images),
        "num_text_tokens": torch.tensor(all_num_text_tokens),
    }
    if is_raw_images:
        output["pixel_values"] = all_images
    else:
        output["image_embeddings"] = all_images

    if is_image_at_end:
        # Set the correct attention mask based on whether the image is at the start
        # or at the end. When it is at the end, we need the next-image attention mask.
        output["image_attention_mask"] = next_image_attention_mask

    return output
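
# Illustrative usage of `split_and_pad_pmd` (comment-only sketch, hypothetical names): PMD-style
# image/caption pairs are packed into fixed-length rows, with the <image> token placed at the
# start or the end of each caption depending on `prob_image_at_end`:
#   batch = split_and_pad_pmd(
#       sample={"text": captions, "image": images},
#       tokenizer=tokenizer,
#       max_seq_len=128,
#       image_transform=image_transform,
#       max_num_images=5,
#   )
#   batch["input_ids"]   # (num_packed_samples, 128)
#   batch["num_images"]  # number of real (non-padding) images per packed row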


# Copied from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py
def random_spans_helper(
    inputs_length,
    noise_density,
    mean_noise_span_length,
    extra_tokens_per_span_inputs,
    extra_tokens_per_span_targets,
    verbose=False,
):
    """Training parameters to avoid padding with random_spans_noise_mask.

    When training a model with random_spans_noise_mask, we would like to set the
    other training hyperparameters in a way that avoids padding. This function
    helps us compute these hyperparameters.

    We assume that each noise span in the input is replaced by
    extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
    targets is replaced by extra_tokens_per_span_targets sentinel tokens.

    This function tells us the required number of tokens in the raw example (for
    split_tokens()) as well as the length of the encoded targets.

    Note that this function assumes the inputs and targets will have EOS appended
    and includes that in the reported length.

    Args:
      inputs_length: an integer - desired length of the tokenized inputs sequence
      noise_density: a float
      mean_noise_span_length: a float
      extra_tokens_per_span_inputs: an integer
      extra_tokens_per_span_targets: an integer
      verbose: a bool indicating whether to log sequence lengths

    Returns:
      tokens_length: length of original text in tokens
      targets_length: an integer - length in tokens of encoded targets sequence
    """
    if extra_tokens_per_span_inputs != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_inputs != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )
    if extra_tokens_per_span_targets != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_targets != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
            num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
        )

    tokens_length = inputs_length - 1
    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    if verbose:
        logging.info(
            "tokens_length=%s inputs_length=%s targets_length=%s noise_density=%s mean_noise_span_length=%s ",
            tokens_length,
            inputs_length,
            targets_length,
            noise_density,
            mean_noise_span_length,
        )
    return tokens_length, targets_length
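
# Worked example for `random_spans_helper` (comment-only sketch, hypothetical values): with
# inputs_length=512, noise_density=0.15, mean_noise_span_length=3.0 and one sentinel per span
# on each side,
#   random_spans_helper(512, 0.15, 3.0, 1, 1)
# returns (tokens_length=568, targets_length=114): 568 raw tokens yield 85 noise tokens in
# 28 spans, so the inputs pack to exactly 483 + 28 + 1 = 512 tokens and the targets to
# 85 + 28 + 1 = 114 tokens.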