| |
|
|
| import logging |
| import os |
| from typing import Any, Callable, Dict, List, Optional |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from apps.plm.dataset_conf import DatasetConf |
|
|
# Module-level logger named after this module's import path (stdlib convention).
logger = logging.getLogger(__name__)
|
|
|
|
class VisionPreprocessor:
    """Turns raw dataset rows (image / multi-image / video / text) into tokenized samples.

    For each row this class resolves media paths against an optional root
    directory, loads and transforms the media, and runs the conversation plus
    media through the supplied multimodal tokenizer. Any failure drops the
    sample (returns None) rather than aborting the data pipeline.
    """

    def __init__(
        self,
        transform: Optional[Callable],
        tokenizer: Callable,
        max_video_frames: Optional[int],
        dataset_config: DatasetConf,
    ):
        """
        Args:
            transform: mapping of media kind ("image", "video", "region") to
                transform callables. NOTE(review): the video and text paths in
                ``process`` index it unconditionally, so it may effectively be
                required for those sample types — confirm with callers.
            tokenizer: callable invoked as
                ``tokenizer(conversations=..., media=..., media_type=...)``.
            max_video_frames: cap on frames sampled per video, or None.
            dataset_config: dataset settings; only ``root_dir`` is read here.
        """
        self.mllm_tokenizer = tokenizer
        self.transform = transform
        # Prefix for relative media paths; "" means row paths are used as-is.
        self.root_dir = dataset_config.root_dir if dataset_config.root_dir else ""
        self.max_video_frames = max_video_frames

    def __call__(self, row: Dict[str, Any], rng: np.random.RandomState):
        """Process one row; on any exception, log it and drop the sample (None)."""
        try:
            return self.process(row, rng)
        except Exception:
            # logger.exception records the full traceback, which logger.error
            # with an interpolated message did not.
            logger.exception("Error processing row")
            return None

    def process(self, row: Dict[str, Any], rng: np.random.RandomState):
        """Convert a raw annotation row into a tokenized training sample.

        Args:
            row: annotation dict. Must contain either "conversations" or
                "text"; may contain "image" (path or list of paths), "video",
                "bbox" (which then requires "width" and "height"), and the
                optional video keys "start_time", "end_time", "bbox_map".
            rng: unused; kept so the signature matches the pipeline's
                per-sample callable contract.

        Returns:
            Dict with keys "media", "text_ids", "response_pos", "image_pos",
            "num_image_chunks", "media_type" — or None when the tokenizer
            reports the sample as invalid or an image fails to load.
        """
        del rng  # processing is deterministic; parameter kept for interface parity
        if "conversations" in row:
            conversations = row["conversations"]
        else:
            conversations = self.get_conversation(caption=row["text"], prompt="")

        if "bbox" in row:
            # Region annotations need the original image size to place boxes.
            # (f-prefix removed: the message has no placeholders.)
            assert (
                "width" in row and "height" in row
            ), "bbox is present in the annotation, however image width or height is not specified, which is not expected."
            w, h = row["width"], row["height"]
            bboxes = row["bbox"]
            conversations = self.transform["region"](conversations, bboxes, w, h)

        media = None
        media_type = ""
        if "image" in row:
            processed_images = []
            image_files = row["image"]
            if isinstance(image_files, str):
                image_files = [image_files]
            pil_images = []
            for image_file in image_files:
                if self.root_dir:
                    image_file = os.path.join(self.root_dir, image_file)
                try:
                    image = Image.open(image_file).convert("RGB")
                    pil_images.append(image)
                except Exception as e:
                    # A single unreadable image drops the whole sample; warn
                    # (not info) so these losses are visible, with lazy %-args.
                    logger.warning(
                        "loading image failed because of the following error:\n %s",
                        e,
                    )
                    return None
            if self.transform:
                if len(pil_images) == 1:
                    transform = self.transform["image"]
                    processed_images, _ = transform(pil_images[0])
                else:
                    # Multiple images are processed like a stack of video frames.
                    transform = self.transform["video"]
                    processed_images, _ = transform._process_multiple_images_pil(
                        pil_images
                    )
            # Ensure a leading image/frame dimension: (C, H, W) -> (1, C, H, W).
            # NOTE(review): if self.transform is falsy, processed_images is
            # still the empty list and .shape raises (caught by __call__,
            # sample dropped) — confirm transform is always set for image rows.
            if len(processed_images.shape) == 3:
                processed_images = processed_images.unsqueeze(0)
            media = processed_images
            media_type = "multi_image" if len(image_files) > 1 else "image"
        elif "video" in row:
            video_file = row["video"]
            start_time = row.get("start_time", None)
            bbox_map = row.get("bbox_map", None)
            end_time = row.get("end_time", None)
            if self.root_dir:
                video_file = os.path.join(self.root_dir, video_file)
            # The video transform consumes a single tuple describing the clip.
            video_info = (
                video_file,
                self.max_video_frames,
                start_time,
                end_time,
                bbox_map,
            )
            video, _ = self.transform["video"](video_info)
            media = video
            media_type = "video"
        else:
            # Text-only sample: feed a dummy all-ones image so downstream code
            # still receives a media tensor of the expected (1, 3, S, S) shape.
            media = torch.ones(
                1, 3, self.transform["image"].size, self.transform["image"].size
            )
            media_type = "text"

        tokenized_sample = self.mllm_tokenizer(
            conversations=conversations, media=media, media_type=media_type
        )
        out = (
            {
                "media": media,
                "text_ids": tokenized_sample.text_ids,
                "response_pos": tokenized_sample.response_pos,
                "image_pos": tokenized_sample.image_pos,
                "num_image_chunks": tokenized_sample.num_media_chunks,
                "media_type": media_type,
            }
            if tokenized_sample.is_valid
            else None
        )
        return out

    def get_conversation(
        self, caption: str, prompt: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Converts a plain caption to a two-turn conversation.

        Args:
            caption (str): plain caption, used as the assistant turn
            prompt (Optional[str]): optional human-turn text; "" when None

        Returns:
            List[Dict[str, str]]: conversation
        """
        conversations = [
            {"from": "human", "value": prompt if prompt is not None else ""},
            {"from": "assistant", "value": caption},
        ]
        return conversations
|
|