File size: 5,877 Bytes
21ac790 001b110 21ac790 001b110 21ac790 001b110 21ac790 001b110 21ac790 127b72a 001b110 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
"""
Processor class for Molmo.
"""
from typing import List, Union, Optional
from transformers.utils.constants import OPENAI_CLIP_STD, OPENAI_CLIP_MEAN
try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
import numpy as np
import torch
from transformers.image_utils import ImageInput
from transformers.processing_utils import (
TextKwargs,
ProcessingKwargs,
ProcessorMixin,
)
from transformers.tokenization_utils_base import TextInput
from transformers.utils import logging
from transformers import AutoTokenizer
from .image_preprocessing_molmo import MolmoImagesKwargs, make_batched_images, MolmoImageProcessor
# Module-level logger obtained from ``transformers.utils.logging``.
# NOTE(review): not referenced by any code visible in this file; confirm it is still needed.
logger = logging.get_logger(__name__)
# Special token strings used to mark image content in Molmo token sequences.
# Plain string literals: the previous ``f"..."`` prefixes had no placeholders.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"

# All extra tokens; ``get_special_token_ids`` encodes their concatenation and
# maps each token string to its id.
EXTRA_TOKENS = (
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_COL_TOKEN,
    IMAGE_PROMPT,
)
def get_special_token_ids(tokenizer):
    """Return a dict mapping each string in ``EXTRA_TOKENS`` to its token id.

    The extra tokens are concatenated and encoded in a single call (without
    special tokens); the resulting ids are paired positionally with
    ``EXTRA_TOKENS``.

    Args:
        tokenizer: Object with an ``encode(text, add_special_tokens=...)``
            method returning a sequence of ids.

    Returns:
        dict: ``{token_string: token_id}`` for every entry of ``EXTRA_TOKENS``.

    Raises:
        ValueError: If encoding does not yield exactly one id per extra token,
            i.e. the tokens are not each encoded as a single id.
    """
    ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if len(ids) != len(EXTRA_TOKENS):
        raise ValueError(
            f"Expected {len(EXTRA_TOKENS)} ids for special tokens {EXTRA_TOKENS}, "
            f"but the tokenizer produced {len(ids)}: {ids}"
        )
    return dict(zip(EXTRA_TOKENS, ids))
class MolmoTextKwargs(TextKwargs, total=False):
    """Text-side keyword arguments accepted by ``MolmoProcessor.process``.

    Every key is optional (``total=False``). Default values are declared in
    ``MolmoProcessorKwargs._defaults["text_kwargs"]``.
    """

    # Prompt style name.
    # NOTE(review): not read by any code visible in this file.
    style: Union[str, None]
    # System prompt selector.
    # NOTE(review): not read by any code visible in this file.
    system_prompt: Union[str, None]
    # "none"/None leaves the prompt as-is; "role" wraps it as "User: ... Assistant:".
    message_format: Union[str, None]
    # When true, a single space is prepended to the prompt before tokenizing.
    always_start_with_space: Union[bool, None]
    # Passed through as ``sequence_length`` to ``multimodal_preprocess``.
    sequence_length: Union[int, None]
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema and default values for ``MolmoProcessor.process``.

    ``_defaults`` holds per-modality defaults; it is presumably read by
    ``ProcessorMixin._merge_kwargs`` (called from ``process``) and combined
    with caller-supplied overrides — confirm against the transformers version.
    """

    text_kwargs: MolmoTextKwargs
    images_kwargs: MolmoImagesKwargs

    _defaults = dict(
        # Forwarded as ``**images_kwargs`` to ``image_processor.multimodal_preprocess``.
        images_kwargs=dict(
            max_crops=12,
            overlap_margins=[4, 4],
            base_image_input_size=[336, 336],
            image_token_length_w=12,
            image_token_length_h=12,
            image_patch_size=14,
            image_padding_mask=True,
        ),
        text_kwargs=dict(
            style="long_caption",
            system_prompt="none",
            # Wrap prompts as "User: <prompt> Assistant:".
            message_format="role",
            always_start_with_space=True,
            sequence_length=1536,
            padding=False,
        ),
    )
class MolmoProcessor(ProcessorMixin):
    """Combine a Molmo image processor and a Qwen2 tokenizer into model inputs.

    ``process`` tokenizes a prompt, lets the image processor splice image
    tokens into the token sequence via ``multimodal_preprocess``, prepends a
    BOS (or EOS) separator token, and converts every output to a torch tensor.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
        """Register the image processor and tokenizer with ``ProcessorMixin``.

        Args:
            image_processor: The Molmo image processor.
            tokenizer: The text tokenizer (one of ``tokenizer_class``).
            **kwargs: Accepted but ignored; they are not forwarded to
                ``ProcessorMixin.__init__``.
        """
        super().__init__(image_processor, tokenizer)
        # Filled lazily by ``special_token_ids`` on first access.
        self._special_tokens = None

    @property
    def special_token_ids(self):
        """Dict mapping each Molmo extra token string to its tokenizer id.

        Computed with ``get_special_token_ids`` on first access, then cached.
        """
        if self._special_tokens is None:
            self._special_tokens = get_special_token_ids(self.tokenizer)
        return self._special_tokens

    def get_tokens_input(self, prompt, message_format, always_start_with_space):
        """Format ``prompt`` according to ``message_format`` and tokenize it.

        Args:
            prompt: The prompt text.
            message_format: ``"none"`` or ``None`` to use the prompt as-is;
                ``"role"`` to wrap it as ``"User: <prompt> Assistant:"``.
            always_start_with_space: If true, prepend a single space before
                tokenizing.

        Returns:
            The ids returned by ``tokenizer.encode`` (no special tokens added).

        Raises:
            NotImplementedError: For any other ``message_format``.
        """
        if message_format == "none" or message_format is None:
            pass
        elif message_format == "role":
            prompt = "User: " + prompt + " Assistant:"
        else:
            raise NotImplementedError(f"Message format {message_format} not implemented")
        if always_start_with_space:
            prompt = " " + prompt
        return self.tokenizer.encode(prompt, add_special_tokens=False)

    def process(
        self,
        text: TextInput = None,
        images: ImageInput = None,
        **kwargs: Unpack[MolmoProcessorKwargs],
    ):
        """Build model inputs for one prompt and an optional set of images.

        Args:
            text: The prompt text. NOTE(review): ``None`` (the default) is not
                handled explicitly and is passed straight to ``get_tokens_input``.
            images: Optional image input; normalized by ``make_batched_images``
                and each item converted to a ``uint8`` NumPy array.
            **kwargs: Overrides merged with ``MolmoProcessorKwargs`` defaults.

        Returns:
            The mapping produced by ``image_processor.multimodal_preprocess``,
            with a separator token prepended to ``input_ids``,
            ``image_input_idx`` (if present) shifted to match, and every value
            converted with ``torch.from_numpy``.
        """
        output_kwargs = self._merge_kwargs(
            MolmoProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        text_kwargs = output_kwargs["text_kwargs"]
        tokens = self.get_tokens_input(
            text,
            text_kwargs["message_format"],
            text_kwargs["always_start_with_space"],
        )

        if images is not None:
            images = make_batched_images(images)
            images = [np.array(image).astype(np.uint8) for image in images]
            # For now only support inserting images at the start
            image_idx = [-1] * len(images)
        else:
            image_idx = None

        special_ids = self.special_token_ids
        out = self.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=text_kwargs["sequence_length"],
            image_patch_token_id=special_ids[DEFAULT_IMAGE_PATCH_TOKEN],
            image_col_token_id=special_ids[DEFAULT_IM_COL_TOKEN],
            image_start_token_id=special_ids[DEFAULT_IM_START_TOKEN],
            image_end_token_id=special_ids[DEFAULT_IM_END_TOKEN],
            **output_kwargs["images_kwargs"],
        )

        # Prepend BOS. Qwen2 and OLMo have no BOS token and instead use EOS as a
        # generic separator token. Compare against None explicitly: a valid BOS
        # id of 0 is falsy and must not fall through to EOS.
        bos = self.tokenizer.bos_token_id
        if bos is None:
            bos = self.tokenizer.eos_token_id
        out["input_ids"] = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)

        if "image_input_idx" in out:
            # Shift the patch -> token position mapping up by one for the
            # prepended token; negative entries are left unchanged.
            image_input_idx = out["image_input_idx"]
            out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)

        # Convert in place so the container type returned by the image
        # processor is preserved.
        for key, value in out.items():
            out[key] = torch.from_numpy(value)
        return out
# Register MolmoProcessor with transformers' auto-class machinery so it can be
# resolved from a saved checkpoint (presumably via AutoProcessor with
# trust_remote_code=True — confirm).
MolmoProcessor.register_for_auto_class()
|