| |
|
|
import re
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, List, Optional

import torch

from core.distributed import get_is_master
|
|
| logger = getLogger() |
|
|
|
|
@dataclass
class MLLMBatch:
    """One collated multimodal batch: token ids, shifted labels, and image data."""

    x: torch.LongTensor  # input token ids, shape (bsz, seq_len), int64
    y: torch.LongTensor  # label token ids, same shape/dtype as x
    mask: Optional[torch.BoolTensor] = None  # loss mask over y, same shape as x
    image_pos_index: Optional[torch.LongTensor] = None  # image chunk index per position, -1 = none
    images: Optional[torch.Tensor] = None  # concatenated image chunks, or None for text-only
    # BUG FIX: the previous default was `(["text"],)` -- a tuple accidentally
    # created by a stray trailing comma (which also hid a shared mutable list
    # from dataclass's mutable-default check). Use a default_factory so each
    # instance gets its own fresh `["text"]` list.
    media_type: Optional[List[str]] = field(default_factory=lambda: ["text"])
    num_image_chunks: Optional[List[int]] = None  # per-sample image chunk counts

    def __post_init__(self):
        # Validate the (batch, seq) layout and the int64 dtype downstream code expects.
        assert self.x.dim() == 2, "{} != 2".format(self.x.dim())
        assert self.x.shape == self.y.shape
        assert self.x.dtype == torch.int64
        assert self.y.dtype == torch.int64
        assert self.mask is None or self.mask.shape == self.x.shape
|
|
|
class BaseCollator:
    """Abstract base for batch collators.

    Holds the tokenizer and a flag controlling whether the first collated
    batch is dumped to the log; subclasses must implement ``__call__``.
    """

    def __init__(self, tokenizer, show_first_batch: bool = False) -> None:
        self.tokenizer = tokenizer
        # Consumed (and reset) by subclasses that log their first batch.
        self.first_batch = show_first_batch

    def __call__(self, features: List[Dict[str, Any]]):
        """Collate a list of per-sample feature dicts into a batch."""
        raise NotImplementedError
|
|
|
|
class MllmPaddingCollator(BaseCollator):
    """Collator that pads variable-length multimodal samples to ``tokenizer.seq_len``.

    Produces next-token-prediction inputs/labels, a loss mask over response
    positions, and an index tensor mapping sequence positions to image chunks.
    """

    def prettify_decoded_text(self, texts: List[str]) -> List[str]:
        """Shorten runs of repeated special tokens for readable logging.

        A run of two or more identical special tokens (e.g. 40 consecutive
        ``<|image|>`` placeholders) is rewritten as ``<|image|>..x40``.
        Single occurrences are left untouched.
        """
        special_tokens = ["<|end_of_text|>", "<|image|>"]
        # Precompile one pattern per token instead of rebuilding the regex and
        # redefining a replacement closure inside the per-text loop.
        compiled = [
            (token, re.compile(f"({re.escape(token)})\\1+")) for token in special_tokens
        ]
        prettified = []
        for text in texts:
            for token, pattern in compiled:
                text = pattern.sub(
                    # Bind `token` via a default argument; the repeat count is
                    # the matched span length divided by the token length.
                    lambda m, t=token: f"{t}..x{len(m.group(0)) // len(t)}",
                    text,
                )
            prettified.append(text)
        return prettified

    def __call__(self, features: List[Dict[str, Any]]) -> MLLMBatch:
        """Collate per-sample feature dicts into one padded :class:`MLLMBatch`.

        Each feature dict is expected to carry ``text_ids``, ``media``,
        ``response_pos``, ``image_pos``, ``num_image_chunks`` and
        ``media_type`` (schema inferred from usage here -- confirm against
        the dataset side).
        """
        text = []
        images = []
        media_type = []
        response_pos = []
        image_pos = []
        num_image_chunks = []
        for b in features:
            text.append(b["text_ids"])
            images.append(b["media"])
            response_pos.append(b["response_pos"])
            image_pos.append(b["image_pos"])
            num_image_chunks.append(b["num_image_chunks"])
            media_type.append(b["media_type"])

        # Concatenate image tensors from the samples that actually carry media.
        images = [img for img in images if img is not None]
        images = torch.cat(images) if images else None

        # torch.full with an int fill value yields int64, which satisfies the
        # dtype asserts in MLLMBatch.__post_init__.
        bsz = len(text)
        input_ids = torch.full(
            (bsz, self.tokenizer.seq_len), self.tokenizer.pad_token_id
        )
        label_ids = torch.full(
            (bsz, self.tokenizer.seq_len), self.tokenizer.pad_token_id
        )
        image_pos_index = torch.full((bsz, self.tokenizer.seq_len), -1)

        for i in range(bsz):
            # Labels are shifted left by one: position j-1 predicts token j.
            # NOTE(review): a response position of 0 would write index -1 (the
            # last slot) -- assumed not to occur; confirm upstream.
            for j in response_pos[i]:
                label_ids[i][j - 1] = text[i][j]
            # Inputs drop the final token (standard next-token setup).
            # Assumes len(text[i]) - 1 <= seq_len; longer samples would raise.
            text_len = len(text[i]) - 1
            input_ids[i][:text_len] = torch.tensor(text[i][:-1])
            # Map each image placeholder position to its chunk index.
            if image_pos[i]:
                image_indices = torch.arange(len(image_pos[i]))
                image_pos_index[i, image_pos[i]] = image_indices

        # Loss is computed only where a real response label was written,
        # so the mask must be taken BEFORE pads are folded into EOS below.
        mask = label_ids.ne(self.tokenizer.pad_token_id)

        # Replace pad ids with EOS so the model never sees the pad token id.
        input_ids[input_ids == self.tokenizer.pad_token_id] = (
            self.tokenizer.eos_token_id
        )
        label_ids[label_ids == self.tokenizer.pad_token_id] = (
            self.tokenizer.eos_token_id
        )

        # One-time debug dump of the first batch, on the master rank only.
        if self.first_batch and get_is_master():
            input_decoded = self.tokenizer.decode_batch(input_ids)
            label_decoded = self.tokenizer.decode_batch(label_ids)
            logger.info(f"Input text: \n{self.prettify_decoded_text(input_decoded)}")
            logger.info(f"Label text: \n{self.prettify_decoded_text(label_decoded)}")
            self.first_batch = False

        return MLLMBatch(
            x=input_ids,
            y=label_ids,
            mask=mask,
            image_pos_index=image_pos_index,
            images=images,
            media_type=media_type,
            num_image_chunks=num_image_chunks,
        )
|
|